Ejemplo n.º 1
0
  def testGetVariable(self, use_resource):
    """A variable made by `tf.get_variable` is reported to the callback once.

    Ref (non-resource) variables are skipped when executing eagerly.
    """
    eager = tf.executing_eagerly()
    if eager and not use_resource:
      self.skipTest("Ref variables not supported in eager mode.")

    created = []
    with util.notify_about_new_variables(created.append):
      with tf.variable_scope("", use_resource=use_resource):
        var = tf.get_variable("x", [])

    self.assertVariableType(var, use_resource)
    # The callback must have seen exactly this one variable.
    self.assertEqual(created, [var])
Ejemplo n.º 2
0
    def _capture_variables(self):
        """Adds variables used by this module to `self._all_variables`.

        This is a generator that yields exactly once and is used as a
        context manager. The decorator is outside this view; presumably it
        is `contextlib.contextmanager` (TODO confirm).

        On entry, `self` is pushed onto the module stack returned by
        `get_module_stack()`. While the body runs:
          * if `self._template` has a non-None `_template_store`, that
            store's `as_default()` context is active;
          * every variable reported by `util.notify_about_new_variables` is
            added to `self._all_variables` (a set-like object, since it is
            used via `.add` / `.update`).

        When the body finishes normally:
          * if `self._original_name` is truthy, all of
            `self._template.variables` are also added to
            `self._all_variables`;
          * the top of the module stack is popped;
          * if the stack is still non-empty, this module's
            `_all_variables` are merged into the new top's
            `_all_variables`. That new top is the parent module.

        If the body raises, the stack is still popped (in `finally`). The
        exception then propagates, so the `_original_name` update and the
        merge into the parent are skipped.

        Yields:
          Nothing, the yield just transfers focus back to the inner context.
        """
        module_stack = get_module_stack()
        module_stack.append(self)
        try:
            with contextlib2.ExitStack() as stack:
                # Ideally move re-entering store into Template.variable_scope.
                # `getattr` with a default tolerates templates that have no
                # `_template_store` attribute at all.
                template_store = getattr(self._template, "_template_store",
                                         None)
                if template_store is not None:
                    # In eager mode, the template store keeps references to created
                    # variables such that they survive even if there are no references to
                    # them in Python code. Variables added to an eager template store are
                    # also added to TensorFlow global collections (unlike regular
                    # variables created in eager mode).
                    stack.enter_context(template_store.as_default())

                # Record every variable created while the body runs. This is
                # entered after the template store, so it is exited first.
                stack.enter_context(
                    util.notify_about_new_variables(self._all_variables.add))

                yield

                # Also record all of the template's variables, not just those
                # reported by the notifier above.
                # NOTE(review): the meaning of `_original_name` is defined
                # elsewhere; confirm why this is conditional on it.
                if self._original_name:
                    self._all_variables.update(self._template.variables)

        finally:
            # Remove `self` from `module_stack`, this happens as part of cleanup
            # even if an error is raised.
            # NOTE: pops whatever is on top without checking it is `self`;
            # this relies on nested modules entering/exiting in LIFO order.
            module_stack.pop()

        # Reached only on normal exit; an exception from the body skips this.
        if module_stack:
            # Peek into the stack to add created variables to the parent
            parent_module = module_stack[-1]
            parent_module._all_variables.update(self._all_variables)  # pylint: disable=protected-access
Ejemplo n.º 3
0
  def testVariableCreatingCustomGetter(self, variable_type, stack_entries):
    """Checks which variables are reported when getters/creators are stacked.

    Inside a root variable scope, the contexts named by `stack_entries` are
    entered in order (outermost first):
      * "notify": `util.notify_about_new_variables`, recording into a list.
      * "custom_getter": a variable scope using `my_custom_getter`.
      * "variable_creator": a variable creator scope using `my_custom_getter`.

    `my_custom_getter` creates an extra `<name>_additional` variable that it
    does not return. The test asserts whether that extra variable is
    reported to the notifier.

    Args:
      variable_type: "ResourceVariable" selects resource variables; any other
        value selects ref variables.
      stack_entries: sequence (list or tuple) of the strings above.

    Raises:
      AssertionError: if `stack_entries` contains an unrecognized entry.
    """
    use_resource = variable_type == "ResourceVariable"

    if tf.executing_eagerly() and not use_resource:
      self.skipTest("Ref variables not supported in eager mode.")

    def my_custom_getter(getter, **kwargs):
      var = getter(**kwargs)
      # Create an additional variable in the getter which is not returned.
      kwargs["name"] += "_additional"
      getter(**kwargs)
      return var

    variables = []

    with contextlib2.ExitStack() as stack:
      stack.enter_context(tf.variable_scope("", use_resource=use_resource))
      for stack_entry in stack_entries:
        if stack_entry == "notify":
          stack.enter_context(util.notify_about_new_variables(variables.append))
        elif stack_entry == "custom_getter":
          stack.enter_context(
              tf.variable_scope("", custom_getter=my_custom_getter))
        elif stack_entry == "variable_creator":
          stack.enter_context(
              variable_scope_ops.variable_creator_scope(my_custom_getter))
        else:
          raise AssertionError(
              "Unknown stack entry: {!r}".format(stack_entry))

      v = tf.get_variable("v", [])

    self.assertVariableType(v, use_resource)
    # Use a distinct comprehension variable so that `v` is never rebound
    # (list comprehensions leak their loop variable under Python 2).
    reported_names = [var.name for var in variables]
    # Normalize to a list: a tuple parameter would otherwise never compare
    # equal to the list literal and would take the wrong branch.
    if list(stack_entries) == ["variable_creator", "notify"]:
      # When a variable creator is entered before `notify_about_new_variables`
      # there is no way for us to identify what additional variables that
      # creator created.
      self.assertEqual(reported_names, [u"v:0"])
    else:
      self.assertEqual(reported_names, [u"v:0", u"v_additional:0"])
Ejemplo n.º 4
0
 def testNoVariables(self):
   """Entering the notifier without creating any variable reports nothing."""
   seen = []
   with util.notify_about_new_variables(seen.append):
     pass  # Deliberately create no variables inside the context.
   self.assertEqual(seen, [])