def test_name_scope_duplicate_name(self):
    """Re-opening a scope whose name was already used yields a uniquified prefix."""
    created = []
    # Two separate name_scope("foo") objects, each entered and exited in turn.
    for _ in range(2):
        with module.name_scope("foo"):
            created.append(module.Module(name="bar"))
    for mod, expected in zip(created, ("foo/bar", "foo_1/bar")):
        self.assertEqual(mod.module_name, expected)
def test_name_scope_reenter(self):
    """Entering a name_scope object while it is already active must fail."""
    foo_scope = module.name_scope("foo")
    # foo_scope is entered first, then the assertion context; the inner
    # re-entry below is what is expected to raise.
    with foo_scope, self.assertRaisesRegex(
            ValueError, "name_scope is not reentrant"):
        with foo_scope:
            pass
def test_name_scope_reuse_after_use(self):
    """A name_scope that was used successfully cannot be entered a second time.

    This test has its own name so that it is not shadowed by the method
    ``test_name_scope_reuse`` defined later in the same class. Two methods
    with one name means only the last definition runs.
    """
    scope = module.name_scope("foo")
    with scope:
        mod = module.Module(name="bar")
    # The first (and only permitted) use behaves like any other name scope.
    self.assertEqual(mod.module_name, "foo/bar")
    # Re-entering the same scope object after it has exited must fail.
    with self.assertRaisesRegex(ValueError, "name_scope is not reusable"):
        with scope:
            pass
def test_name_scope_reuse_after_error(self):
    """A name_scope exited via an exception still cannot be re-entered."""
    scope = module.name_scope("foo")
    with self.assertRaisesRegex(AssertionError, "expected"):
        with scope:
            # Raise explicitly instead of `assert False, "expected"`: assert
            # statements are removed under `python -O`. The scope would then
            # exit normally and the surrounding assertRaisesRegex would fail.
            raise AssertionError("expected")
    with self.assertRaisesRegex(ValueError, "name_scope is not reusable"):
        with scope:
            pass
def test_name_scope_reuse(self):
    # NOTE: Before relaxing this restriction, think carefully about what a
    # caller would observe in a case like:
    #
    #   def f(x):
    #     foo_scope = name_scope("foo")
    #     with foo_scope: x = BarModule()(x)  # name: foo/bar_module
    #     with foo_scope: x = BarModule()(x)  # name: foo/bar_module
    #     return x
    #
    # Reusing the scope object would hand out the same name twice. We expect
    # that to surprise users and cause bugs. It does, however, match what
    # would happen if the body of the context manager were moved into a
    # method and that method were called twice.
    foo_scope = module.name_scope("foo")
    with foo_scope:
        pass
    reuse_error = "name_scope is not reusable"
    # The assertion context is entered first, so the failure raised when
    # re-entering foo_scope is caught by it.
    with self.assertRaisesRegex(ValueError, reuse_error), foo_scope:
        pass
def __call__(self):
    """Create parameter ``w`` in the current scope and again under ``foo``.

    Returns the pair ``(outer_w, foo_w)``, in that order.
    """
    def _make_w():
        # Both parameters share name, shape and initializer; only the
        # surrounding name scope differs.
        return base.get_parameter("w", [], init=jnp.zeros)

    outer_w = _make_w()
    with module.name_scope("foo"):
        foo_w = _make_w()
    return outer_w, foo_w
def test_name_scope_outside_transform(self):
    """Calling name_scope outside of `hk.transform` raises ValueError."""
    # Callable form of assertRaisesRegex: the call itself must raise.
    self.assertRaisesRegex(
        ValueError,
        "name_scope.*must be used as part of an `hk.transform`",
        module.name_scope,
        "foo",
    )
def test_name_scope_leading_slash(self):
    """A scope name beginning with "/" is rejected when name_scope is called."""
    self.assertRaisesRegex(
        ValueError,
        "Name scopes must not start with /",
        module.name_scope,
        "/foo",
    )
def test_name_scope_nesting(self):
    """Nested name scopes are joined with "/" in the module name."""
    # Equivalent to nesting the two `with` statements: "foo" is entered first.
    with module.name_scope("foo"), module.name_scope("bar"):
        baz = module.Module(name="baz")
    self.assertEqual(baz.module_name, "foo/bar/baz")
def test_name_scope_slash_delimited(self):
    """A single name_scope may contribute several "/"-separated components."""
    expected_name = "foo/bar/baz"
    with module.name_scope("foo/bar"):
        baz = module.Module(name="baz")
    self.assertEqual(baz.module_name, expected_name)
def test_name_scope_trivial(self):
    """Same-named modules inside one scope get uniquified suffixes."""
    with module.name_scope("foo"):
        bars = [module.Module(name="bar") for _ in range(2)]
    self.assertEqual([m.module_name for m in bars], ["foo/bar", "foo/bar_1"])
def test_deferred_naming_name_scope(self):
    """A Deferred module created inside a name scope is named under that scope.

    NOTE(review): the target is reconstructed here as being accessed after the
    scope closes. This assumes Deferred remembers the naming context it was
    built in. That is inferred from the expected name; confirm against the
    original layout.
    """
    with module.name_scope("foo"):
        lazy = deferred.Deferred(ExampleModule)
    target = lazy.target
    self.assertEqual(target.module_name, "foo/example_module")