def test_reuse_root(self):
    """``auto_reuse_variables`` accepts the root variable scope.

    The first entry into the root scope is expected not to be in reuse
    mode. Re-entering it, this time nested under ``tf.variable_scope('a')``,
    is expected to be in reuse mode and to give back the very same ``v0``.
    The expected op name passed to ``_check_vs`` moves to ``a/op:0``.
    """
    root_scope = tf.get_variable_scope()

    with auto_reuse_variables(root_scope) as scope:
        self.assertFalse(scope.reuse)
        created = self._check_vs('v0', '', '', 'v0:0', 'op:0')

    with tf.variable_scope('a'), auto_reuse_variables(root_scope) as scope:
        self.assertTrue(scope.reuse)
        reused = self._check_vs('v0', '', '', 'v0:0', 'a/op:0')
        self.assertIs(reused, created)
def test_reopen_name_scope(self):
    """Re-enter a captured scope object with ``reopen_name_scope=True``.

    After ``'a'`` is opened for the first time (not in reuse mode),
    re-entering the captured scope object with ``reopen_name_scope=True``
    must:

    * be in reuse mode;
    * return the same ``v1`` variable;
    * keep the expected op inside the original ``a/`` name scope
      (``a/op_1:0``).

    Requesting a variable that was never created must then raise
    ``ValueError``.
    """
    import re  # used to match the expected error text literally

    with auto_reuse_variables('a') as a:
        self.assertFalse(a.reuse)
        v1 = self._check_vs('v1', 'a', 'a/', 'a/v1:0', 'a/op:0')

    with auto_reuse_variables(a, reopen_name_scope=True) as vs:
        self.assertTrue(vs.reuse)
        v1_2 = self._check_vs('v1', 'a', 'a/', 'a/v1:0', 'a/op_1:0')
        self.assertIs(v1_2, v1)

        # assertRaisesRegex interprets its argument as a regex. Unescaped,
        # "()" is an empty group and "." matches any character, so the
        # trailing "tf.get_variable()" would not really be checked.
        with self.assertRaisesRegex(
                ValueError,
                re.escape('Variable a/v2 does not exist, or was not '
                          'created with tf.get_variable()')):
            tf.get_variable('v2', shape=())
def test_different_graph(self):
    """Reuse state is expected to be tracked per ``tf.Graph``.

    Inside one graph, the second entry into ``'a'`` reuses ``v1``. In a
    brand-new graph, the first entry into ``'a'`` is again not in reuse
    mode, and the variable it yields is a different object.
    """
    with tf.Graph().as_default():
        with auto_reuse_variables('a') as created_scope:
            self.assertFalse(created_scope.reuse)
            original = self._check_vs('v1', 'a', 'a/', 'a/v1:0', 'a/op:0')

        with auto_reuse_variables('a') as reused_scope:
            self.assertTrue(reused_scope.reuse)
            again = self._check_vs('v1', 'a', 'a_1/', 'a/v1:0', 'a_1/op:0')
            self.assertIs(again, original)

    with tf.Graph().as_default(), auto_reuse_variables('a') as fresh_scope:
        self.assertFalse(fresh_scope.reuse)
        other = self._check_vs('v1', 'a', 'a/', 'a/v1:0', 'a/op:0')
        self.assertIsNot(other, original)
def test_errors(self):
    """``auto_reuse_variables`` rejects invalid arguments with ValueError.

    Both ``''`` and ``None`` are refused as an empty ``name_or_scope``.
    ``reopen_name_scope=True`` is refused when the scope is given as a plain
    string name.
    """
    for empty_name in ('', None):
        with pytest.raises(ValueError,
                           match='`name_or_scope` cannot be empty'):
            with auto_reuse_variables(empty_name):
                pass

    with pytest.raises(ValueError,
                       match='`reopen_name_scope` can be set to True '
                             'only if `name_or_scope` is an instance of '
                             '`tf.VariableScope`'):
        with auto_reuse_variables('a', reopen_name_scope=True):
            pass
def test_errors(self):
    """``auto_reuse_variables`` rejects invalid arguments with ValueError.

    Both ``''`` and ``None`` are refused as an empty ``name_or_scope``.
    ``reopen_name_scope=True`` is refused when the scope is given as a plain
    string name.
    """
    for empty_name in ('', None):
        with self.assertRaisesRegex(
                ValueError, '`name_or_scope` cannot be empty.'):
            with auto_reuse_variables(empty_name):
                pass

    with self.assertRaisesRegex(
            ValueError,
            '`reopen_name_scope` can be set to True '
            'only if `name_or_scope` is an instance of '
            '`tf.VariableScope`.'):
        with auto_reuse_variables('a', reopen_name_scope=True):
            pass
def __call__(self, inputs, **kwargs):
    """Run ``self._forward(inputs, **kwargs)`` inside this module's scopes.

    The call first enters ``self.variable_scope`` through
    ``auto_reuse_variables(..., reopen_name_scope=True)``. It then opens a
    nested ``'forward'`` name scope and returns whatever ``_forward``
    returns.
    """
    # reopen_name_scope=True: repeated calls on the same Module instance
    # build their operations inside the original name scope.
    # For ``tf.variable_scope(default_name=...)`` to cooperate with variable
    # reuse, each call still needs a unique name scope, which the nested
    # ``tf.name_scope('forward')`` supplies.
    reuse_ctx = auto_reuse_variables(self.variable_scope,
                                     reopen_name_scope=True)
    with reuse_ctx, tf.name_scope('forward'):
        return self._forward(inputs, **kwargs)
def test_basic_reuse(self):
    """Basic reuse: re-enter ``'a'`` by name and then by scope object.

    The first entry into ``'a'`` is not in reuse mode and creates ``v1``.
    Re-entering by name must be in reuse mode and return the same ``v1``.
    Inside it, requesting an unknown variable must raise ``ValueError``.
    Re-entering through the captured scope object must also be in reuse
    mode and return ``v1``.
    """
    import re  # used to match the expected error text literally

    with auto_reuse_variables('a') as a:
        self.assertFalse(a.reuse)
        v1 = self._check_vs('v1', 'a', 'a/', 'a/v1:0', 'a/op:0')

    with auto_reuse_variables('a') as vs:
        self.assertTrue(vs.reuse)
        v1_2 = self._check_vs('v1', 'a', 'a_1/', 'a/v1:0', 'a_1/op:0')
        self.assertIs(v1_2, v1)

        # ``match`` is a regex; escape it so "()" and "." are matched
        # literally instead of as an empty group and a wildcard.
        with pytest.raises(
                ValueError,
                match=re.escape('Variable a/v2 does not exist, or was not '
                                'created with tf.get_variable()')):
            tf.get_variable('v2', shape=())

    # Bind the scope yielded by *this* context. The reuse assertion then
    # checks this entry rather than the stale ``vs`` from the block above.
    with auto_reuse_variables(a) as vs:
        self.assertTrue(vs.reuse)
        v1_3 = self._check_vs('v1', 'a', 'a/', 'a/v1:0', 'a_2/op:0')
        self.assertIs(v1_3, v1)
def __call__(self, *args, **kwargs):
    """Run ``self._call(*args, **kwargs)`` inside this component's scopes.

    The call first enters ``self.variable_scope`` through
    ``auto_reuse_variables(..., reopen_name_scope=True)``. It then opens a
    nested ``'build'`` name scope and returns whatever ``_call`` returns.
    """
    # reopen_name_scope=True: repeated calls on the same Component instance
    # do not each occupy a new unique name scope derived from the original.
    #
    # For ``tf.variable_scope(default_name=...)`` to cooperate with variable
    # reuse, each call still needs its own unique nested name scope, which
    # ``tf.name_scope('build')`` supplies.
    reuse_ctx = auto_reuse_variables(self.variable_scope,
                                     reopen_name_scope=True)
    with reuse_ctx, tf.name_scope('build'):
        return self._call(*args, **kwargs)
def test_nested_reuse(self):
    """Reuse of the nested scope ``a/b`` reached in several ways.

    ``'b'`` is first opened inside ``'a'`` (not in reuse mode). It is then
    re-entered:

    * by relative name inside ``'a'``;
    * by scope object inside ``'a'``;
    * by full path from the root;
    * by scope object from the root.

    Every re-entry is expected to be in reuse mode and to return the
    original ``v1``. The expected name scopes handed to ``_check_vs`` differ
    depending on how the scope was reached.
    """
    with auto_reuse_variables('a') as scope_a:
        self.assertFalse(scope_a.reuse)

        # First entry into a/b: not in reuse mode, creates v1.
        with auto_reuse_variables('b') as scope_b:
            self.assertFalse(scope_b.reuse)
            created = self._check_vs(
                'v1', 'a/b', 'a/b/', 'a/b/v1:0', 'a/b/op:0')

        # Re-enter by relative name while still inside 'a'.
        with auto_reuse_variables('b') as scope:
            self.assertTrue(scope.reuse)
            by_name = self._check_vs(
                'v1', 'a/b', 'a/b_1/', 'a/b/v1:0', 'a/b_1/op:0')
            self.assertIs(by_name, created)

        # Re-enter through the captured scope object while inside 'a'.
        with auto_reuse_variables(scope_b) as scope:
            self.assertTrue(scope.reuse)
            by_object = self._check_vs(
                'v1', 'a/b', 'a/b/', 'a/b/v1:0', 'a/b_2/op:0')
            self.assertIs(by_object, created)

    # Re-enter by full path from the root.
    with auto_reuse_variables('a/b') as scope:
        self.assertTrue(scope.reuse)
        by_path = self._check_vs(
            'v1', 'a/b', 'a/b_3/', 'a/b/v1:0', 'a/b_3/op:0')
        self.assertIs(by_path, created)

    # Re-enter through the scope object from the root.
    # The resulting name scope 'b' is an odd quirk of `tf.variable_scope`
    # itself. We may not agree with it, but we have to follow it.
    with auto_reuse_variables(scope_b) as scope:
        self.assertTrue(scope.reuse)
        from_root = self._check_vs(
            'v1', 'a/b', 'a/b/', 'a/b/v1:0', 'b/op:0')
        self.assertIs(from_root, created)
def test_mix_reuse_and_variable_scope(self):
    """Mix ``auto_reuse_variables`` with a plain ``tf.variable_scope``.

    ``'b'`` is opened twice via ``auto_reuse_variables`` under a plain
    ``tf.variable_scope('a')``; the second entry is in reuse mode. Later,
    ``'a'`` is entered through ``auto_reuse_variables``. That first entry is
    expected not to be in reuse mode, while the nested ``'b'`` still is and
    returns the original ``v1``.
    """
    with tf.variable_scope('a') as plain_a:
        self.assertFalse(plain_a.reuse)

        with auto_reuse_variables('b') as first_b:
            self.assertFalse(first_b.reuse)
            created = self._check_vs(
                'v1', 'a/b', 'a/b/', 'a/b/v1:0', 'a/b/op:0')

        with auto_reuse_variables('b') as second_b:
            self.assertTrue(second_b.reuse)
            reused = self._check_vs(
                'v1', 'a/b', 'a/b_1/', 'a/b/v1:0', 'a/b_1/op:0')
            self.assertIs(reused, created)

    # 'a' was previously entered only through tf.variable_scope. The test
    # expects this first auto_reuse_variables('a') to be non-reusing, while
    # 'a/b' stays reusing.
    with auto_reuse_variables('a') as auto_a:
        self.assertFalse(auto_a.reuse)
        with auto_reuse_variables('b') as third_b:
            self.assertTrue(third_b.reuse)
            reused_again = self._check_vs(
                'v1', 'a/b', 'a_1/b/', 'a/b/v1:0', 'a_1/b/op:0')
            self.assertIs(reused_again, created)