Esempio n. 1
0
    def testGetVariable(self, use_resource):
        """Checks that `tf.get_variable` reports exactly its new variable.

        Args:
          use_resource: Passed to the enclosing `tf.variable_scope`; the ref
            variable case is skipped when executing eagerly.
        """
        eager = tf.executing_eagerly()
        if eager and not use_resource:
            self.skipTest("Ref variables not supported in eager mode.")

        seen = []
        with util.notify_about_variables(seen.append):
            with tf.variable_scope("", use_resource=use_resource):
                var_x = tf.get_variable("x", [])
        self.assertVariableType(var_x, use_resource)
        # The callback must have been invoked once, with the created variable.
        self.assertEqual(seen, [var_x])
Esempio n. 2
0
  def testGetVariable(self, use_resource):
    """Checks that `tf.get_variable` reports exactly its new variable.

    Args:
      use_resource: Passed to the enclosing `tf.variable_scope`; the ref
        variable case is skipped when executing eagerly.
    """
    eager = tf.executing_eagerly()
    if eager and not use_resource:
      self.skipTest("Ref variables not supported in eager mode.")

    seen = []
    with util.notify_about_variables(seen.append):
      with tf.variable_scope("", use_resource=use_resource):
        var_x = tf.get_variable("x", [])
    self.assertVariableType(var_x, use_resource)
    # The callback must have been invoked once, with the created variable.
    self.assertEqual(seen, [var_x])
Esempio n. 3
0
    def testVariableCreatingCustomGetter(self, variable_type, stack_entries):
        """Checks what `notify_about_variables` reports for extra getter variables.

        `my_custom_getter` below returns the requested variable but also
        creates a second, unreturned `<name>_additional` variable. It is
        installed as a variable-scope custom getter and/or a variable creator,
        interleaved with `util.notify_about_variables`, in the order given by
        `stack_entries`.

        Args:
          variable_type: "ResourceVariable" requests resource variables; any
            other value requests ref variables.
          stack_entries: Sequence (list or tuple) of "notify", "custom_getter"
            and/or "variable_creator", entered left to right.
        """
        use_resource = variable_type == "ResourceVariable"

        if tf.executing_eagerly() and not use_resource:
            self.skipTest("Ref variables not supported in eager mode.")

        def my_custom_getter(getter, **kwargs):
            var = getter(**kwargs)
            # Create an additional variable in the getter which is not returned.
            kwargs["name"] += "_additional"
            getter(**kwargs)
            return var

        variables = []

        with contextlib2.ExitStack() as stack:
            stack.enter_context(
                tf.variable_scope("", use_resource=use_resource))
            for stack_entry in stack_entries:
                if stack_entry == "notify":
                    stack.enter_context(
                        util.notify_about_variables(variables.append))
                elif stack_entry == "custom_getter":
                    stack.enter_context(
                        tf.variable_scope("", custom_getter=my_custom_getter))
                elif stack_entry == "variable_creator":
                    stack.enter_context(
                        variable_scope_ops.variable_creator_scope(
                            my_custom_getter))
                else:
                    # Say which entry was unexpected so a bad parameter is
                    # easy to diagnose.
                    raise AssertionError(
                        "Unknown stack entry: %r" % (stack_entry,))

            v = tf.get_variable("v", [])

        self.assertVariableType(v, use_resource)
        # Use a distinct loop name so the comprehension cannot rebind `v`
        # (comprehension variables leak into the function scope on Python 2).
        reported_names = [var.name for var in variables]
        # Normalize to a tuple so this check holds whether the parameterization
        # supplies `stack_entries` as a list or as a tuple.
        if tuple(stack_entries) == ("variable_creator", "notify"):
            # When a variable creator is entered before `notify_about_variables`
            # there is no way for us to identify what additional variables that
            # creator created.
            self.assertEqual(reported_names, [u"v:0"])
        else:
            self.assertEqual(reported_names, [u"v:0", u"v_additional:0"])
Esempio n. 4
0
    def _capture_variables(self):
        """Collects the variables created under this module into `_all_variables`.

        While the context is active, `self` sits on top of `_MODULE_STACK`.
        Every variable reported through `util.notify_about_variables`, for
        example one created with `tf.get_variable()` inside `_build()` or
        `_enter_variable_scope()`, is added to `self._all_variables`. If the
        module's template has a `_template_store`, that store is also made the
        default for the duration of the context.

        On exit, `self` is always popped from `_MODULE_STACK`, even if the body
        raises. Only on a normal (non-exceptional) exit are the collected
        variables merged into the `_all_variables` of the new top of the stack,
        i.e. the enclosing module, if there is one.

        Yields:
          Nothing; control simply passes back to the body of the `with` block.
        """
        _MODULE_STACK.append(self)
        try:
            with contextlib2.ExitStack() as exit_stack:
                # Re-entering the store here would ideally be handled by
                # Template.variable_scope instead.
                store = getattr(self._template, "_template_store", None)
                if store is not None:
                    # In eager mode the template store holds references to the
                    # variables it creates, keeping them alive even when no
                    # Python code references them. Variables added to an eager
                    # template store also go into TensorFlow's global
                    # collections (unlike ordinary eager variables).
                    exit_stack.enter_context(store.as_default())

                # Record every variable created while this module is on top.
                collect = self._all_variables.add
                exit_stack.enter_context(util.notify_about_variables(collect))

                yield
        finally:
            # Pop unconditionally so an exception in the body cannot leave
            # `self` stranded on the module stack.
            _MODULE_STACK.pop()

        if _MODULE_STACK:
            # Hand the captured variables up to the enclosing module.
            parent = _MODULE_STACK[-1]
            parent._all_variables.update(self._all_variables)  # pylint: disable=protected-access
Esempio n. 5
0
  def _capture_variables(self):
    """Collects the variables created under this module into `_all_variables`.

    While the context is active, `self` sits on top of `_MODULE_STACK`. Every
    variable reported through `util.notify_about_variables`, for example one
    created with `tf.get_variable()` inside `_build()` or
    `_enter_variable_scope()`, is added to `self._all_variables`. If the
    module's template has a `_template_store`, that store is also made the
    default for the duration of the context.

    On exit, `self` is always popped from `_MODULE_STACK`, even if the body
    raises. Only on a normal (non-exceptional) exit are the collected
    variables merged into the `_all_variables` of the new top of the stack,
    i.e. the enclosing module, if there is one.

    Yields:
      Nothing; control simply passes back to the body of the `with` block.
    """
    _MODULE_STACK.append(self)
    try:
      with contextlib2.ExitStack() as exit_stack:
        # Re-entering the store here would ideally be handled by
        # Template.variable_scope instead.
        store = getattr(self._template, "_template_store", None)
        if store is not None:
          # In eager mode the template store holds references to the
          # variables it creates, keeping them alive even when no Python code
          # references them. Variables added to an eager template store also
          # go into TensorFlow's global collections (unlike ordinary eager
          # variables).
          exit_stack.enter_context(store.as_default())

        # Record every variable created while this module is on top.
        collect = self._all_variables.add
        exit_stack.enter_context(util.notify_about_variables(collect))

        yield
    finally:
      # Pop unconditionally so an exception in the body cannot leave `self`
      # stranded on the module stack.
      _MODULE_STACK.pop()

    if _MODULE_STACK:
      # Hand the captured variables up to the enclosing module.
      parent = _MODULE_STACK[-1]
      parent._all_variables.update(self._all_variables)  # pylint: disable=protected-access
Esempio n. 6
0
  def testVariableCreatingCustomGetter(self, variable_type, stack_entries):
    """Checks what `notify_about_variables` reports for extra getter variables.

    `my_custom_getter` below returns the requested variable but also creates a
    second, unreturned `<name>_additional` variable. It is installed as a
    variable-scope custom getter and/or a variable creator, interleaved with
    `util.notify_about_variables`, in the order given by `stack_entries`.

    Args:
      variable_type: "ResourceVariable" requests resource variables; any other
        value requests ref variables.
      stack_entries: Sequence (list or tuple) of "notify", "custom_getter"
        and/or "variable_creator", entered left to right.
    """
    use_resource = variable_type == "ResourceVariable"

    if tf.executing_eagerly() and not use_resource:
      self.skipTest("Ref variables not supported in eager mode.")

    def my_custom_getter(getter, **kwargs):
      var = getter(**kwargs)
      # Create an additional variable in the getter which is not returned.
      kwargs["name"] += "_additional"
      getter(**kwargs)
      return var

    variables = []

    with contextlib2.ExitStack() as stack:
      stack.enter_context(tf.variable_scope("", use_resource=use_resource))
      for stack_entry in stack_entries:
        if stack_entry == "notify":
          stack.enter_context(util.notify_about_variables(variables.append))
        elif stack_entry == "custom_getter":
          stack.enter_context(
              tf.variable_scope("", custom_getter=my_custom_getter))
        elif stack_entry == "variable_creator":
          stack.enter_context(
              variable_scope_ops.variable_creator_scope(my_custom_getter))
        else:
          # Say which entry was unexpected so a bad parameter is easy to
          # diagnose.
          raise AssertionError("Unknown stack entry: %r" % (stack_entry,))

      v = tf.get_variable("v", [])

    self.assertVariableType(v, use_resource)
    # Use a distinct loop name so the comprehension cannot rebind `v`
    # (comprehension variables leak into the function scope on Python 2).
    reported_names = [var.name for var in variables]
    # Normalize to a tuple so this check holds whether the parameterization
    # supplies `stack_entries` as a list or as a tuple.
    if tuple(stack_entries) == ("variable_creator", "notify"):
      # When a variable creator is entered before `notify_about_variables` there
      # is no way for us to identify what additional variables that creator
      # created.
      self.assertEqual(reported_names, [u"v:0"])
    else:
      self.assertEqual(reported_names, [u"v:0", u"v_additional:0"])
Esempio n. 7
0
 def testNoVariables(self):
     """Checks that the callback is never invoked if no variable is created."""
     reported = []
     with util.notify_about_variables(reported.append):
         pass  # Intentionally create nothing inside the context.
     self.assertEqual(reported, [])
Esempio n. 8
0
 def testNoVariables(self):
   """Checks that the callback is never invoked if no variable is created."""
   reported = []
   with util.notify_about_variables(reported.append):
     pass  # Intentionally create nothing inside the context.
   self.assertEqual(reported, [])