def testGetVariable(self, use_resource):
  """A `tf.get_variable` call inside a notify block reports exactly that variable.

  Args:
    use_resource: Forwarded to `tf.variable_scope`; ref variables (False) are
      skipped when running eagerly.
  """
  if tf.executing_eagerly() and not use_resource:
    self.skipTest("Ref variables not supported in eager mode.")

  reported = []
  # Multi-item `with` enters the notifier first, then builds and enters the
  # variable scope -- the same order as two nested `with` statements.
  with util.notify_about_variables(reported.append), tf.variable_scope(
      "", use_resource=use_resource):
    created = tf.get_variable("x", [])
    self.assertVariableType(created, use_resource)

  self.assertEqual(reported, [created])
def testVariableCreatingCustomGetter(self, variable_type, stack_entries):
  """Checks which variables `notify_about_variables` sees under custom getters.

  The getter used below creates one extra, unreturned variable named
  "<name>_additional". The test asserts whether that extra variable is
  reported, depending on the order in which contexts are entered.

  Args:
    variable_type: "ResourceVariable" selects resource variables; any other
      value selects ref variables (skipped when executing eagerly).
    stack_entries: Ordered list of contexts to enter inside the outer variable
      scope. Each entry is "notify", "custom_getter" or "variable_creator".

  Raises:
    AssertionError: If `stack_entries` contains an unknown entry.
  """
  use_resource = variable_type == "ResourceVariable"
  if tf.executing_eagerly() and not use_resource:
    self.skipTest("Ref variables not supported in eager mode.")

  def my_custom_getter(getter, **kwargs):
    var = getter(**kwargs)
    # Create an additional variable in the getter which is not returned.
    kwargs["name"] += "_additional"
    getter(**kwargs)
    return var

  variables = []
  with contextlib2.ExitStack() as stack:
    stack.enter_context(tf.variable_scope("", use_resource=use_resource))
    for stack_entry in stack_entries:
      if stack_entry == "notify":
        stack.enter_context(util.notify_about_variables(variables.append))
      elif stack_entry == "custom_getter":
        stack.enter_context(
            tf.variable_scope("", custom_getter=my_custom_getter))
      elif stack_entry == "variable_creator":
        stack.enter_context(
            variable_scope_ops.variable_creator_scope(my_custom_getter))
      else:
        # Name the bad parameterization instead of failing with a bare error.
        raise AssertionError("Unknown stack entry: %r" % (stack_entry,))
    v = tf.get_variable("v", [])
    self.assertVariableType(v, use_resource)

  # Use a distinct loop name so `v` is not rebound (comprehension variables
  # leak into the enclosing scope under Python 2).
  variable_names = [variable.name for variable in variables]
  if stack_entries == ["variable_creator", "notify"]:
    # When a variable creator is entered before `notify_about_variables` there
    # is no way for us to identify what additional variables that creator
    # created.
    self.assertEqual(variable_names, [u"v:0"])
  else:
    self.assertEqual(variable_names, [u"v:0", u"v_additional:0"])
def _capture_variables(self):
  """Collects variables created while this module is active.

  Generator-style context manager. On entry, `self` is pushed onto
  `_MODULE_STACK`. While the context is open, every variable reported to
  `util.notify_about_variables` (e.g. ones made with `tf.get_variable()`
  inside `_build()` or `_enter_variable_scope()`) is added to
  `self._all_variables`. On exit, including exit through an exception, `self`
  is popped again. If another module is then on top of the stack, that parent
  also receives everything in `self._all_variables`.

  Yields:
    Nothing; the bare `yield` simply hands control back to the `with` body.
  """
  _MODULE_STACK.append(self)
  try:
    with contextlib2.ExitStack() as exit_stack:
      # Ideally move re-entering store into Template.variable_scope.
      # In eager mode the template store keeps references to the variables it
      # creates, so they survive without Python references. Variables added
      # to an eager template store are also added to TensorFlow global
      # collections (unlike regular eager variables).
      store = getattr(self._template, "_template_store", None)
      if store is not None:
        exit_stack.enter_context(store.as_default())
      collect = self._all_variables.add
      exit_stack.enter_context(util.notify_about_variables(collect))
      yield
  finally:
    # Unwind the module stack unconditionally, even if the body raised.
    _MODULE_STACK.pop()
    # Hand the captured variables up to the enclosing module, if there is one.
    if _MODULE_STACK:
      _MODULE_STACK[-1]._all_variables.update(self._all_variables)  # pylint: disable=protected-access
def testVariableCreatingCustomGetter(self, variable_type, stack_entries):
  """Checks which variables `notify_about_variables` sees under custom getters.

  The getter used below creates one extra, unreturned variable named
  "<name>_additional". The test asserts whether that extra variable is
  reported, depending on the order in which contexts are entered.

  Args:
    variable_type: "ResourceVariable" selects resource variables; any other
      value selects ref variables (skipped when executing eagerly).
    stack_entries: Ordered list of contexts to enter inside the outer variable
      scope. Each entry is "notify", "custom_getter" or "variable_creator".

  Raises:
    AssertionError: If `stack_entries` contains an unknown entry.
  """
  use_resource = variable_type == "ResourceVariable"
  if tf.executing_eagerly() and not use_resource:
    self.skipTest("Ref variables not supported in eager mode.")

  def my_custom_getter(getter, **kwargs):
    var = getter(**kwargs)
    # Create an additional variable in the getter which is not returned.
    kwargs["name"] += "_additional"
    getter(**kwargs)
    return var

  variables = []
  with contextlib2.ExitStack() as stack:
    stack.enter_context(tf.variable_scope("", use_resource=use_resource))
    for stack_entry in stack_entries:
      if stack_entry == "notify":
        stack.enter_context(util.notify_about_variables(variables.append))
      elif stack_entry == "custom_getter":
        stack.enter_context(
            tf.variable_scope("", custom_getter=my_custom_getter))
      elif stack_entry == "variable_creator":
        stack.enter_context(
            variable_scope_ops.variable_creator_scope(my_custom_getter))
      else:
        # Name the bad parameterization instead of failing with a bare error.
        raise AssertionError("Unknown stack entry: %r" % (stack_entry,))
    v = tf.get_variable("v", [])
    self.assertVariableType(v, use_resource)

  # Use a distinct loop name so `v` is not rebound (comprehension variables
  # leak into the enclosing scope under Python 2).
  variable_names = [variable.name for variable in variables]
  if stack_entries == ["variable_creator", "notify"]:
    # When a variable creator is entered before `notify_about_variables` there
    # is no way for us to identify what additional variables that creator
    # created.
    self.assertEqual(variable_names, [u"v:0"])
  else:
    self.assertEqual(variable_names, [u"v:0", u"v_additional:0"])
def testNoVariables(self):
  """A notify block that creates nothing reports nothing."""
  seen = []
  notifier = util.notify_about_variables(seen.append)
  with notifier:
    pass
  self.assertEqual(seen, [])