Example #1
def build(self, _):
    for column in self._feature_columns:
        with variable_scope._pure_variable_scope(  # pylint: disable=protected-access
                self.name,
                partitioner=self._partitioner):
            with variable_scope._pure_variable_scope(  # pylint: disable=protected-access
                    feature_column_v2._sanitize_column_name_for_variable_scope(  # pylint: disable=protected-access
                        column.name)):
                column.create_state(self._state_manager)
    super(_BaseFeaturesLayer, self).build(None)
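For context, a minimal sketch of what `_pure_variable_scope` does differently from the public `tf.variable_scope` (assuming the TF 1.x API via `tensorflow.compat.v1`): the pure variant re-enters a variable scope without opening a matching name scope, so variable names are prefixed but op names are not.

```python
import tensorflow.compat.v1 as tf
from tensorflow.python.ops import variable_scope as variable_scope_ops

tf.disable_eager_execution()

with tf.variable_scope("outer") as vs:
    pass  # capture the VariableScope object

# Re-enter the variable scope only; no name scope is opened.
with variable_scope_ops._pure_variable_scope(vs):  # pylint: disable=protected-access
    w = tf.get_variable("w", shape=[])  # variable name: "outer/w"
    c = tf.constant(0.0, name="c")      # op name: "c" (no "outer/" prefix)
```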
Example #2
  def wrapper(obj, *args, **kwargs):
    def default_context_manager(reuse=None):
      variable_scope = obj.variable_scope
      return tf.variable_scope(variable_scope, reuse=reuse)

    variable_scope_context_manager = getattr(obj, "_enter_variable_scope",
                                             default_context_manager)
    # Note: `graph` is unused in this fragment; tf.AUTO_REUSE below removes
    # the need for per-graph reuse bookkeeping.
    graph = tf.get_default_graph()

    # Temporarily enter the variable scope to capture it
    with variable_scope_context_manager() as tmp_variable_scope:
      variable_scope = tmp_variable_scope

    with variable_scope_ops._pure_variable_scope(
        variable_scope, reuse=tf.AUTO_REUSE) as pure_variable_scope:

      name_scope = variable_scope.original_name_scope
      if name_scope[-1] != "/":
        name_scope += "/"

      with tf.name_scope(name_scope):
        sub_scope = snt_util.to_snake_case(method.__name__)
        with tf.name_scope(sub_scope) as scope:
          out_ops = method(obj, *args, **kwargs)
          return out_ops
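`method` and `snt_util` are free variables here: this `wrapper` is the inner function of a decorator and captures `method` by closure. A hedged sketch of how such a decorator is typically assembled (illustrative skeleton, not Sonnet's actual code):

```python
import functools

def reuse_variables(method):
    """Illustrative decorator skeleton; the real body is the fragment above."""
    @functools.wraps(method)
    def wrapper(obj, *args, **kwargs):
        # ... capture obj's variable scope and re-enter it with
        # reuse=tf.AUTO_REUSE, as in the example above ...
        return method(obj, *args, **kwargs)
    return wrapper
```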
Example #3
    def __init__(self,
                 name,
                 func,
                 create_scope_now=False,
                 unique_name=None,
                 custom_getter=None,
                 create_graph_function=False):
        """Creates a template for the given function.

    Args:
      name: A name for the scope created by this template. The
        name will be made unique by appending `_N` to it (see how
        `tf.variable_scope` treats the `default_name` for details).
      func: The function to apply each time.
      create_scope_now: Whether to create the scope at Template construction
        time, rather than first call. Defaults to false. Creating the scope at
        construction time may be more convenient if the template is to be passed
        through much lower level code, and you want to be sure of the scope
        name without knowing exactly where it will be first called. If set to
        True, the scope will be created in the constructor, and all subsequent
        times in `__call__`, leading to a trailing numeral being added to the
        names of all created Tensors. If set to False, the scope will be created
        at the first call location.
      unique_name: When used, it overrides `name` and is not made unique. If a
        template of the same scope/unique_name already exists and reuse is
        false, an error is raised. Defaults to None.
      custom_getter: optional custom getter to pass to `variable_scope()`
      create_graph_function: When True, the first invocation of the template
        will execute `func` as is, to allow for variable creation; however, the
        second invocation and every invocation thereafter will execute `func` as
        a graph function. Enabling this flag gives the caller access to
        graph-function semantics, i.e., accesses to variables are totally
        ordered and side-effecting ops are not pruned.

    Raises:
      ValueError: if the name is None.
    """
        self._func = func
        self._stacktrace = traceback.format_stack()[:-2]
        self._name = name
        self._unique_name = unique_name
        self._custom_getter = custom_getter
        if name is None:
            raise ValueError("name cannot be None.")
        if create_scope_now:
            with variable_scope._pure_variable_scope(  # pylint:disable=protected-access
                (self._unique_name
                 or variable_scope._get_unique_variable_scope(self._name)),  # pylint:disable=protected-access
                    custom_getter=self._custom_getter) as vs:
                self._variable_scope = vs
        else:
            self._variable_scope = None
        # This variable keeps track of whether the template has been called yet,
        # which is not the same as whether the scope has been created.
        self._variables_created = False
        self._create_graph_function = create_graph_function
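For orientation, a hedged usage sketch of the `Template` machinery this constructor backs, via the public TF 1.x wrapper `tf.make_template`: both calls below share the same `w` variable, which is exactly what the scope handling above enables.

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

def scale(x):
    w = tf.get_variable("w", shape=[], initializer=tf.ones_initializer())
    return w * x

scale_template = tf.make_template("scale", scale)
y1 = scale_template(tf.constant(2.0))  # first call creates "scale/w"
y2 = scale_template(tf.constant(3.0))  # second call reuses "scale/w"
```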
Example #4
  def __init__(self, name, func, create_scope_now=False, unique_name=None,
               custom_getter=None, create_graph_function=False):
    """Creates a template for the given function.

    Args:
      name: A name for the scope created by this template. The
        name will be made unique by appending `_N` to it (see how
        `tf.variable_scope` treats the `default_name` for details).
      func: The function to apply each time.
      create_scope_now: Whether to create the scope at Template construction
        time, rather than first call. Defaults to false. Creating the scope at
        construction time may be more convenient if the template is to be passed
        through much lower level code, and you want to be sure of the scope
        name without knowing exactly where it will be first called. If set to
        True, the scope will be created in the constructor, and all subsequent
        times in `__call__`, leading to a trailing numeral being added to the
        names of all created Tensors. If set to False, the scope will be created
        at the first call location.
      unique_name: When used, it overrides `name` and is not made unique. If a
        template of the same scope/unique_name already exists and reuse is
        false, an error is raised. Defaults to None.
      custom_getter: optional custom getter to pass to `variable_scope()`
      create_graph_function: When True, `func` will be executed as a graph
        function. Enabling this flag gives the caller access to graph-function
        semantics, i.e., accesses to variables are totally ordered and
        side-effecting ops are not pruned.

    Raises:
      ValueError: if `name` is None.
    """
    if create_graph_function:
      self._func = function.defun(func)
    else:
      self._func = func
    self._stacktrace = traceback.format_stack()[:-2]
    self._name = name
    self._unique_name = unique_name
    self._custom_getter = custom_getter
    if name is None:
      raise ValueError("name cannot be None.")
    if create_scope_now:
      with variable_scope._pure_variable_scope(  # pylint:disable=protected-access
          (self._unique_name or
           variable_scope._get_unique_variable_scope(self._name)),  # pylint:disable=protected-access
          custom_getter=self._custom_getter) as vs:
        self._variable_scope = vs
    else:
      self._variable_scope = None
    # This variable keeps track of whether the template has been called to
    # completion, which is not the same as whether the scope has been created.
    self._variables_created = False
    # `MirroredStrategy` builds the graph with multiple threads. If a
    # `merge_call` happens within a template, multiple calls may be in progress
    # simultaneously. This variable keeps track of whether any call of the
    # template has started.
    self._first_call = True
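The `create_graph_function` branch above wraps `func` with the TF 1.x internal `function.defun`; a minimal sketch of that call in isolation:

```python
from tensorflow.python.eager import function

def doubled(x):
    return 2.0 * x

# What `create_graph_function=True` does to `func`: the wrapped callable
# runs as a graph function, with totally ordered variable accesses.
doubled_fn = function.defun(doubled)
```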
Example #5
@contextmanager  # from contextlib; the `yield` below makes this a context manager
def reopen_variable_scope(var_scope,
                          reuse=None,
                          initializer=None,
                          regularizer=None,
                          caching_device=None,
                          partitioner=None,
                          custom_getter=None,
                          dtype=tf.float32):
    """Reopen the specified `var_scope` and its `original_name_scope`.

    `tf.variable_scope` will not open the original name scope, even if a
    stored `tf.VariableScope` instance is specified.  This method thus
    allows opening exactly the same name scope as the original one.

    Parameters
    ----------
    var_scope : tf.VariableScope
        The variable scope instance.

    reuse : None | bool
        Whether or not to reuse the variables in the opened scope.

    initializer, regularizer, caching_device, partitioner, custom_getter, dtype
        Other parameters for opening the variable scope.
    """
    if not isinstance(var_scope, tf.VariableScope):
        raise TypeError('`var_scope` is expected to be an instance of '
                        '`tf.VariableScope`.')
    old_name_scope = var_scope.original_name_scope

    with variable_scope_ops._pure_variable_scope(var_scope,
                                                 reuse=reuse,
                                                 initializer=initializer,
                                                 regularizer=regularizer,
                                                 caching_device=caching_device,
                                                 partitioner=partitioner,
                                                 custom_getter=custom_getter,
                                                 old_name_scope=old_name_scope,
                                                 dtype=dtype) as vs:
        name_scope = old_name_scope
        if name_scope and not name_scope.endswith('/'):
            name_scope += '/'

        with tf.name_scope(name_scope):
            yield vs
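A hedged usage sketch for `reopen_variable_scope` (TF 1.x graph context assumed): unlike `tf.variable_scope(vs)`, which would open a fresh `model_1/` name scope, this helper reopens the original `model/` name scope.

```python
with tf.variable_scope("model") as vs:
    a = tf.constant(0.0, name="a")  # op name: model/a

with reopen_variable_scope(vs):
    b = tf.constant(0.0, name="b")  # op name: model/b, not model_1/b
```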
Example #6
  def __init__(self, name, func, create_scope_now=False, unique_name=None,
               custom_getter=None):
    """Creates a template for the given function.

    Args:
      name: A name for the scope created by this template. The
        name will be made unique by appending `_N` to it (see how
        `tf.variable_scope` treats the `default_name` for details).
      func: The function to apply each time.
      create_scope_now: Whether to create the scope at Template construction
        time, rather than first call. Defaults to false. Creating the scope at
        construction time may be more convenient if the template is to be passed
        through much lower level code, and you want to be sure of the scope
        name without knowing exactly where it will be first called. If set to
        True, the scope will be created in the constructor, and all subsequent
        times in __call__, leading to a trailing numeral being added to the
        names of all created Tensors. If set to False, the scope will be created
        at the first call location.
      unique_name: When used, it overrides `name` and is not made unique. If a
        template of the same scope/unique_name already exists and reuse is
        false, an error is raised. Defaults to None.
      custom_getter: optional custom getter to pass to variable_scope()

    Raises:
      ValueError: if the name is None.
    """
    self._func = func
    self._stacktrace = traceback.format_stack()[:-2]
    self._name = name
    self._unique_name = unique_name
    self._custom_getter = custom_getter
    if name is None:
      raise ValueError("name cannot be None.")
    if create_scope_now:
      with variable_scope._pure_variable_scope(  # pylint:disable=protected-access
          (self._unique_name or
           variable_scope._get_unique_variable_scope(self._name)),  # pylint:disable=protected-access
          custom_getter=self._custom_getter) as vs:
        self._variable_scope = vs
    else:
      self._variable_scope = None
    # This variable keeps track of whether the template has been called yet,
    # which is not the same as whether the scope has been created.
    self._variables_created = False
Example #7
@contextmanager  # from contextlib; needed since this generator is used as a context manager
def root_variable_scope(**kwargs):
    """
    Open the root variable scope and its name scope.

    Args:
        **kwargs: Named arguments for opening the root variable scope.
    """
    # `tf.variable_scope` does not support opening the root variable scope
    # from an empty name.  It always prepends the name of the current
    # variable scope to the front of the opened variable scope.  So we get
    # the current scope and temporarily pretend it is the root scope.
    scope = tf.get_variable_scope()
    old_name = scope.name
    try:
        scope._name = ''
        with variable_scope_ops._pure_variable_scope('', **kwargs) as vs:
            scope._name = old_name
            with tf.name_scope(None):
                yield vs
    finally:
        scope._name = old_name
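A hedged usage sketch: `root_variable_scope` makes it possible to define variables at the graph root even while nested inside another scope.

```python
with tf.variable_scope("outer"):
    with root_variable_scope():
        w = tf.get_variable("w", shape=[])  # named "w", not "outer/w"
```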
Example #8
@contextmanager  # from contextlib
def root_variable_scope(reuse=None,
                        initializer=None,
                        regularizer=None,
                        caching_device=None,
                        partitioner=None,
                        custom_getter=None,
                        dtype=tf.float32):
    """Open the root name and variable scope.

    Parameters
    ----------
    reuse : None | bool
        Whether or not to reuse the variables in the opened scope.

    initializer, regularizer, caching_device, partitioner, custom_getter, dtype
        Other parameters for opening the variable scope.
    """
    # `tf.variable_scope` does not support opening the root variable scope
    # from an empty name.  It always prepends the name of the current
    # variable scope to the front of the opened variable scope.  So we get
    # the current scope and temporarily pretend it is the root scope.
    scope = tf.get_variable_scope()
    old_name = scope.name
    try:
        scope._name = ''
        with variable_scope_ops._pure_variable_scope(
                '',
                reuse=reuse,
                initializer=initializer,
                regularizer=regularizer,
                caching_device=caching_device,
                partitioner=partitioner,
                custom_getter=custom_getter,
                old_name_scope='',
                dtype=dtype) as vs:
            scope._name = old_name
            with tf.name_scope(None):
                yield vs
    finally:
        scope._name = old_name
Example #9
@contextmanager  # from contextlib
def reopen_variable_scope(var_scope, **kwargs):
    """
    Reopen the specified `var_scope` and its original name scope.

    Unlike :func:`tf.variable_scope`, which does not open the original name
    scope even if a stored :class:`tf.VariableScope` instance is specified,
    this method opens exactly the same name scope as the original one.

    Args:
        var_scope (tf.VariableScope): The variable scope instance.
        **kwargs: Named arguments for opening the variable scope.
    """
    if not isinstance(var_scope, tf.VariableScope):
        raise TypeError(
            '`var_scope` must be an instance of `tf.VariableScope`')
    old_name_scope = var_scope.original_name_scope
    with variable_scope_ops._pure_variable_scope(var_scope, **kwargs) as vs:
        name_scope = old_name_scope
        if name_scope and not name_scope.endswith('/'):
            name_scope += '/'  # pragma: no cover

        with tf.name_scope(name_scope):
            yield vs
Example #10
    def call_method(method, obj, args, kwargs):
        """Calls `method` with a variable scope whose reuse flag is set correctly.

    The first time the wrapper is called it creates a
    `(tf.Graph, tf.VariableScope)` key and checks it for membership in
    `initialized_variable_scopes`. The check is `False` if and only if this is
    the first time the wrapper has been called with the key, otherwise the
    check is `True`. The result of this check is used as the `reuse` flag for
    entering the provided variable scope before calling `method`.

    Here are two examples of how to use the reuse_variables decorator.

    1. Decorate an arbitrary instance method with a `variable_scope` attribute:

      ```python
      class Reusable(object):

        def __init__(self, name):
          with tf.variable_scope(None, default_name=name) as vs:
            self.variable_scope = vs

        @snt.reuse_variables
        def add_a(self, input_tensor):
          a = tf.get_variable("a", shape=input_tensor.get_shape())
          return a + input_tensor

      obj = Reusable("reusable")
      x = tf.constant(5.0)
      out1 = obj.add_a(x)
      out2 = obj.add_a(x)
      # out1 == out2
      ```

    2. Decorating a snt.AbstractModule instance method:

      ```python
      class ReusableModule(snt.AbstractModule):

        @snt.reuse_variables
        def add_a(self, input_tensor):
          a = tf.get_variable("a", shape=input_tensor.get_shape())
          return a + input_tensor

        # We don't need @snt.reuse_variables here because build is
        # wrapped by `tf.make_template` inside `snt.AbstractModule`.
        def _build(self, input_tensor):
          b = tf.get_variable("b", shape=input_tensor.get_shape())
          return b + self.add_a(input_tensor)

      obj = ReusableModule("reusable")
      x = tf.constant(5.0)
      out1 = obj(x)
      out2 = obj(x)
      # out1 == out2
      ```

    Args:
      method: The method to wrap.
      obj: The object instance passed to the wrapped method.
      args: The positional arguments (Tensors) passed to the wrapped method.
      kwargs: The keyword arguments passed to the wrapped method.

    Returns:
      Output of the wrapped method.

    Raises:
      ValueError: If no variable scope is provided or if `method` is a method
                  and a variable_scope keyword argument is also provided.
    """

        # If @reuse_variables is combined with @property, obj is passed in args
        # and method is still unbound at this stage.
        if obj is None:
            obj = args[0]

        def default_context_manager(reuse=None):
            variable_scope = obj.variable_scope
            return tf.variable_scope(variable_scope, reuse=reuse)

        variable_scope_context_manager = getattr(obj, "_enter_variable_scope",
                                                 default_context_manager)

        with ops.init_scope():
            # We need `init_scope` in case we're running inside a defun. In
            # that case what we want is information about where the function
            # will be called, not where the function is being built.
            graph = tf.get_default_graph()
            will_call_in_eager_context = tf.executing_eagerly()

        if will_call_in_eager_context:
            initialized_variable_scopes = initialized_variable_scopes_eager
        else:
            if graph not in initialized_variable_scopes_graph:
                initialized_variable_scopes_graph[graph] = set()
            initialized_variable_scopes = initialized_variable_scopes_graph[
                graph]

        # Temporarily enter the variable scope to capture it
        with variable_scope_context_manager() as tmp_variable_scope:
            variable_scope = tmp_variable_scope

        reuse = variable_scope.name in initialized_variable_scopes

        # Enter the pure variable scope with reuse correctly set
        with variable_scope_ops._pure_variable_scope(  # pylint:disable=protected-access
                variable_scope, reuse=reuse) as pure_variable_scope:
            current_name_scope = tf.get_default_graph().get_name_scope()
            # Force tf.name_scope to treat current_name_scope as an "absolute" scope
            # so we can re-enter it.
            if current_name_scope and current_name_scope[-1] != "/":
                current_name_scope += "/"
            with tf.name_scope(current_name_scope):
                module_name = pure_variable_scope.name
                method_name = to_snake_case(method.__name__)
                method_name_scope = "{}/{}".format(module_name, method_name)
                with tf.name_scope(method_name_scope) as scope:
                    if hasattr(obj, "_capture_variables"):
                        with obj._capture_variables():  # pylint: disable=protected-access
                            out_ops = method(*args, **kwargs)
                    else:
                        out_ops = method(*args, **kwargs)
            initialized_variable_scopes.add(pure_variable_scope.name)
            try:
                # If `obj` is a Sonnet module, let it know it's been connected
                # to the TF graph.
                obj._is_connected = True  # pylint: disable=protected-access
                if not tf.executing_eagerly():
                    obj._add_connected_subgraph(  # pylint: disable=protected-access
                        method, out_ops, scope, *args, **kwargs)
            except AttributeError:
                pass
        return out_ops
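Stripped of TensorFlow specifics, the reuse bookkeeping above is a small pattern; a standalone sketch (illustrative names) of the per-graph set that decides the `reuse` flag:

```python
initialized_variable_scopes_graph = {}

def call_with_reuse(graph, scope_name, fn):
    scopes = initialized_variable_scopes_graph.setdefault(graph, set())
    reuse = scope_name in scopes  # False only on the first call per graph
    out = fn(reuse)               # enter the scope with this flag and run
    scopes.add(scope_name)        # subsequent calls will see reuse=True
    return out
```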
Example #11
    def wrapper(obj, *args, **kwargs):
        """Calls `method` with a variable scope whose reuse flag is set correctly.

    The first time the wrapper is called it creates a
    `(tf.Graph, tf.VariableScope)` key and checks it for membership in
    `initialized_variable_scopes`. The check is `False` if and only if this is
    the first time the wrapper has been called with the key, otherwise the
    check is `True`. The result of this check is used as the `reuse` flag for
    entering the provided variable scope before calling `method`.

    Here are two examples of how to use the reuse_variables decorator.

    1. Decorate an arbitrary instance method with a `variable_scope` attribute:

      ```python
      class Reusable(object):

        def __init__(self, name):
          with tf.variable_scope(name) as vs:
            self.variable_scope = vs

        @snt.reuse_variables
        def add_a(self, input_tensor):
          a = tf.get_variable("a", shape=input_tensor.get_shape())
          return a + input_tensor

      obj = Reusable("reusable")
      x = tf.constant(5.0)
      out1 = obj.add_a(x)
      out2 = obj.add_a(x)
      # out1 == out2
      ```

    2. Decorating a snt.AbstractModule instance method:

      ```python
      class ReusableModule(snt.AbstractModule):

        @snt.reuse_variables
        def add_a(self, input_tensor):
          a = tf.get_variable("a", shape=input_tensor.get_shape())
          return a + input_tensor

        # We don't need @snt.reuse_variables here because build is
        # wrapped by `tf.make_template` inside `snt.AbstractModule`.
        def _build(self, input_tensor):
          b = tf.get_variable("b", shape=input_tensor.get_shape())
          return b + self.add_a(input_tensor)

      obj = ReusableModule("reusable")
      x = tf.constant(5.0)
      out1 = obj(x)
      out2 = obj(x)
      # out1 == out2
      ```

    Args:
      obj: The object instance passed to the wrapped method.
      *args: The positional arguments (Tensors) passed to the wrapped method.
      **kwargs: The keyword arguments passed to the wrapped method.

    Returns:
      Output of the wrapped method.

    Raises:
      ValueError: If no variable scope is provided or if `method` is a method
                  and a variable_scope keyword argument is also provided.
    """
        def default_context_manager(reuse=None):
            variable_scope = obj.variable_scope
            return tf.variable_scope(variable_scope, reuse=reuse)

        variable_scope_context_manager = getattr(obj, "_enter_variable_scope",
                                                 default_context_manager)

        graph = tf.get_default_graph()
        if graph not in initialized_variable_scopes:
            initialized_variable_scopes[graph] = set()
        initialized_variable_scopes_for_graph = initialized_variable_scopes[
            graph]

        # Temporarily enter the variable scope to capture it
        with variable_scope_context_manager() as tmp_variable_scope:
            variable_scope = tmp_variable_scope

        reuse = variable_scope.name in initialized_variable_scopes_for_graph

        # Enter the pure variable scope with reuse correctly set
        with variable_scope_ops._pure_variable_scope(  # pylint:disable=protected-access
                variable_scope, reuse=reuse) as pure_variable_scope:
            # Force tf.name_scope to treat variable_scope.original_name_scope as
            # an "absolute" scope name so we can re-enter it.
            name_scope = variable_scope.original_name_scope
            if name_scope[-1] != "/":
                name_scope += "/"
            with tf.name_scope(name_scope):
                sub_scope = to_snake_case(method.__name__)
                with tf.name_scope(sub_scope):
                    out_ops = method(obj, *args, **kwargs)
                    initialized_variable_scopes_for_graph.add(
                        pure_variable_scope.name)
                    return out_ops
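`to_snake_case` is assumed from Sonnet's utility module; a rough regex approximation for illustration (Sonnet's real implementation handles more edge cases, e.g. runs of capitals):

```python
import re

def to_snake_case(camel_case):
    # "MyLinearModule" -> "my_linear_module"
    return re.sub(r"(?<!^)(?=[A-Z])", "_", camel_case).lower()
```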
Example #12
  def wrapper(obj, *args, **kwargs):
    """Calls `method` with a variable scope whose reuse flag is set correctly.

    The first time the wrapper is called it creates a
    `(tf.Graph, tf.VariableScope)` key and checks it for membership in
    `initialized_variable_scopes`. The check is `False` if and only if this is
    the first time the wrapper has been called with the key, otherwise the
    check is `True`. The result of this check is used as the `reuse` flag for
    entering the provided variable scope before calling `method`.

    Here are two examples of how to use the reuse_variables decorator.

    1. Decorate an arbitrary instance method with a `variable_scope` attribute:

      ```python
      class Reusable(object):

        def __init__(self, name):
          with tf.variable_scope(None, default_name=name) as vs:
            self.variable_scope = vs

        @snt.reuse_variables
        def add_a(self, input_tensor):
          a = tf.get_variable("a", shape=input_tensor.get_shape())
          return a + input_tensor

      obj = Reusable("reusable")
      x = tf.constant(5.0)
      out1 = obj.add_a(x)
      out2 = obj.add_a(x)
      # out1 == out2
      ```

    2. Decorating a snt.AbstractModule instance method:

      ```python
      class ReusableModule(snt.AbstractModule):

        @snt.reuse_variables
        def add_a(self, input_tensor):
          a = tf.get_variable("a", shape=input_tensor.get_shape())
          return a + input_tensor

        # We don't need @snt.reuse_variables here because build is
        # wrapped by `tf.make_template` inside `snt.AbstractModule`.
        def _build(self, input_tensor):
          b = tf.get_variable("b", shape=input_tensor.get_shape())
          return b + self.add_a(input_tensor)

      obj = ReusableModule("reusable")
      x = tf.constant(5.0)
      out1 = obj(x)
      out2 = obj(x)
      # out1 == out2
      ```

    Args:
      obj: The object instance passed to the wrapped method.
      *args: The positional arguments (Tensors) passed to the wrapped method.
      **kwargs: The keyword arguments passed to the wrapped method.

    Returns:
      Output of the wrapped method.

    Raises:
      ValueError: If no variable scope is provided or if `method` is a method
                  and a variable_scope keyword argument is also provided.
    """

    def default_context_manager(reuse=None):
      variable_scope = obj.variable_scope
      return tf.variable_scope(variable_scope, reuse=reuse)

    variable_scope_context_manager = getattr(obj, "_enter_variable_scope",
                                             default_context_manager)

    graph = tf.get_default_graph()
    if graph not in initialized_variable_scopes:
      initialized_variable_scopes[graph] = set()
    initialized_variable_scopes_for_graph = initialized_variable_scopes[graph]

    # Temporarily enter the variable scope to capture it
    with variable_scope_context_manager() as tmp_variable_scope:
      variable_scope = tmp_variable_scope

    reuse = variable_scope.name in initialized_variable_scopes_for_graph

    # Enter the pure variable scope with reuse correctly set
    with variable_scope_ops._pure_variable_scope(  # pylint:disable=protected-access
        variable_scope, reuse=reuse) as pure_variable_scope:
      # Force tf.name_scope to treat variable_scope.original_name_scope as
      # an "absolute" scope name so we can re-enter it.
      name_scope = variable_scope.original_name_scope
      if name_scope[-1] != "/":
        name_scope += "/"
      with tf.name_scope(name_scope):
        sub_scope = to_snake_case(method.__name__)
        with tf.name_scope(sub_scope) as scope:
          if hasattr(obj, "_capture_variables"):
            with obj._capture_variables():  # pylint: disable=protected-access
              out_ops = method(obj, *args, **kwargs)
          else:
            out_ops = method(obj, *args, **kwargs)
          initialized_variable_scopes_for_graph.add(pure_variable_scope.name)
          try:
            # If `obj` is a Sonnet module, let it know it's been connected
            # to the TF graph
            method_positional_args = [obj] + list(args)
            obj._add_connected_subgraph(  # pylint: disable=protected-access
                method, out_ops, scope, *method_positional_args, **kwargs)
          except AttributeError:
            pass
          return out_ops
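The trailing `try/except AttributeError` is duck typing: objects that are not Sonnet modules simply lack `_add_connected_subgraph`, so the notification is silently skipped. A minimal illustration:

```python
class NotAModule(object):
    pass

obj = NotAModule()
try:
    obj._add_connected_subgraph(None, None, None)  # no such attribute
except AttributeError:
    pass  # fine: obj is not a Sonnet module, nothing to record
```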
Example #13
  def call_method(method, obj, args, kwargs):
    """Calls `method` with a variable scope whose reuse flag is set correctly.

    The first time the wrapper is called it creates a
    `(tf.Graph, tf.VariableScope)` key and checks it for membership in
    `initialized_variable_scopes`. The check is `False` if and only if this is
    the first time the wrapper has been called with the key, otherwise the
    check is `True`. The result of this check is used as the `reuse` flag for
    entering the provided variable scope before calling `method`.

    Here are two examples of how to use the reuse_variables decorator.

    1. Decorate an arbitrary instance method with a `variable_scope` attribute:

      ```python
      class Reusable(object):

        def __init__(self, name):
          with tf.variable_scope(None, default_name=name) as vs:
            self.variable_scope = vs

        @snt.reuse_variables
        def add_a(self, input_tensor):
          a = tf.get_variable("a", shape=input_tensor.get_shape())
          return a + input_tensor

      obj = Reusable("reusable")
      x = tf.constant(5.0)
      out1 = obj.add_a(x)
      out2 = obj.add_a(x)
      # out1 == out2
      ```

    2. Decorating a snt.AbstractModule instance method:

      ```python
      class ReusableModule(snt.AbstractModule):

        @snt.reuse_variables
        def add_a(self, input_tensor):
          a = tf.get_variable("a", shape=input_tensor.get_shape())
          return a + input_tensor

        # We don't need @snt.reuse_variables here because build is
        # wrapped by `tf.make_template` inside `snt.AbstractModule`.
        def _build(self, input_tensor):
          b = tf.get_variable("b", shape=input_tensor.get_shape())
          return b + self.add_a(input_tensor)

      obj = ReusableModule("reusable")
      x = tf.constant(5.0)
      out1 = obj(x)
      out2 = obj(x)
      # out1 == out2
      ```

    Args:
      method: The method to wrap.
      obj: The object instance passed to the wrapped method.
      args: The positional arguments (Tensors) passed to the wrapped method.
      kwargs: The keyword arguments passed to the wrapped method.

    Returns:
      Output of the wrapped method.

    Raises:
      ValueError: If no variable scope is provided or if `method` is a method
                  and a variable_scope keyword argument is also provided.
    """

    # If @reuse_variables is combined with @property, obj is passed in args
    # and method is still unbound at this stage.
    if obj is None:
      obj = args[0]

    def default_context_manager(reuse=None):
      variable_scope = obj.variable_scope
      return tf.variable_scope(variable_scope, reuse=reuse)

    variable_scope_context_manager = getattr(obj, "_enter_variable_scope",
                                             default_context_manager)

    with ops.init_scope():
      # We need `init_scope` in case we're running inside a defun. In that
      # case what we want is information about where the function will be
      # called, not where the function is being built.
      graph = tf.get_default_graph()
      will_call_in_eager_context = tf.executing_eagerly()

    if will_call_in_eager_context:
      initialized_variable_scopes = initialized_variable_scopes_eager
    else:
      if graph not in initialized_variable_scopes_graph:
        initialized_variable_scopes_graph[graph] = set()
      initialized_variable_scopes = initialized_variable_scopes_graph[graph]

    # Temporarily enter the variable scope to capture it
    with variable_scope_context_manager() as tmp_variable_scope:
      variable_scope = tmp_variable_scope

    reuse = variable_scope.name in initialized_variable_scopes

    # Enter the pure variable scope with reuse correctly set
    with variable_scope_ops._pure_variable_scope(  # pylint:disable=protected-access
        variable_scope, reuse=reuse) as pure_variable_scope:
      current_name_scope = tf.get_default_graph().get_name_scope()
      # Force tf.name_scope to treat current_name_scope as an "absolute" scope
      # so we can re-enter it.
      if current_name_scope and current_name_scope[-1] != "/":
        current_name_scope += "/"
      with tf.name_scope(current_name_scope):
        module_name = pure_variable_scope.name
        method_name = to_snake_case(method.__name__)
        method_name_scope = "{}/{}".format(module_name, method_name)
        with tf.name_scope(method_name_scope) as scope:
          if hasattr(obj, "_capture_variables"):
            with obj._capture_variables():  # pylint: disable=protected-access
              out_ops = method(*args, **kwargs)
          else:
            out_ops = method(*args, **kwargs)
      initialized_variable_scopes.add(pure_variable_scope.name)
      try:
        # If `obj` is a Sonnet module, let it know it's been connected
        # to the TF graph.
        obj._is_connected = True  # pylint: disable=protected-access
        if not tf.executing_eagerly():
          obj._add_connected_subgraph(  # pylint: disable=protected-access
              method, out_ops, scope, *args, **kwargs)
      except AttributeError:
        pass
    return out_ops
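Finally, a short sketch of why `ops.init_scope` is used above: inside a `defun`, the default graph is the temporary function-building graph, while `init_scope` lifts execution to the outermost context, which is where the function will eventually be called.

```python
from tensorflow.python.framework import ops
import tensorflow.compat.v1 as tf

with ops.init_scope():
    # Inside a defun these report the *outer* graph / eager state,
    # not the function-building graph.
    graph = tf.get_default_graph()
    running_eagerly = tf.executing_eagerly()
```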