def testGet(self):
    """`get` yields the transformed value, falling back to the default if absent."""

    def call_value(k, v):
        # The transformer ignores the key and invokes the stored callable.
        return v()

    reg = registry.Registry("test_registry", value_transformer=call_value)
    reg["a"] = lambda: "xyz"

    # Registered key: the supplied default must be ignored.
    found_without_default = reg.get("a")
    self.assertEqual(found_without_default, "xyz")
    found_with_default = reg.get("a", 3)
    self.assertEqual(found_with_default, "xyz")

    # Unregistered key: None unless a default is supplied.
    missing_without_default = reg.get("b")
    self.assertIsNone(missing_without_default)
    missing_with_default = reg.get("b", 3)
    self.assertEqual(missing_with_default, 3)
def testLen(self):
    """`len` of a registry grows by one with every newly assigned key."""
    reg = registry.Registry("test_registry")
    self.assertEqual(len(reg), 0)

    entries = (("a", lambda: None), ("b", lambda: 4))
    for expected_size, (key, value) in enumerate(entries, start=1):
        reg[key] = value
        self.assertEqual(len(reg), expected_size)
def testNoKeyProvided(self):
    """Registering a function without a key files it under the key "f"."""
    reg = registry.Registry("test")

    # The function must be named `f`: the lookup below uses that name as key.
    def f():
        return 3

    reg.register(f)

    registered = reg["f"]
    self.assertEqual(registered(), 3)
def testTransformer(self):
    """Lookups, `values()` and `items()` all expose transformed values."""

    def add_key_to_result(x, y):
        # x is the registry key, y the stored callable.
        return x + y()

    reg = registry.Registry(
        "test_registry", value_transformer=add_key_to_result)

    # Keyed registration: `register(key)` hands back a callable taking the value.
    register_under_3 = reg.register(3)
    register_under_3(lambda: 5)
    register_under_10 = reg.register(10)
    register_under_10(lambda: 12)

    self.assertEqual(reg[3], 8)
    self.assertEqual(reg[10], 22)

    observed_values = set(reg.values())
    self.assertEqual(observed_values, {8, 22})
    observed_items = set(reg.items())
    self.assertEqual(observed_items, {(3, 8), (10, 22)})
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensor2tensor.utils import registry
# Star import binds the base registry's public names in this module as well;
# several of them are deliberately rebound below.
from tensor2tensor.utils.registry import *

# Adds a subsection to the registries to store specific G2G problems and
# hparams, configured with the base registry's own validator/callback hooks.
registry.Registries.g2g_problems = registry.Registry(
    "g2g_problems",
    validator=registry._problem_name_validator,
    on_set=registry._on_problem_set)
registry.Registries.g2g_hparams = registry.Registry(
    "g2g_hparams",
    value_transformer=registry._hparams_value_transformer)


# Decorators: each one registers in the G2G registry first and then forwards
# the result of that registration to the corresponding base decorator.
def register_problem(x):
    """Register `x` in the G2G problem registry, then as a base problem.

    Args:
      x: passed unchanged to `Registries.g2g_problems.register`.

    Returns:
      The result of `registry.register_problem` applied to whatever the G2G
      registration returned.
    """
    # NOTE(review): if `x` is a key rather than the object being registered,
    # `Registry.register` presumably returns a decorator, and that decorator
    # (not the problem) is what reaches `registry.register_problem` -- confirm
    # the keyed form `@register_problem("name")` is meant to be supported.
    return registry.register_problem(
        registry.Registries.g2g_problems.register(x))


def register_hparams(x):
    """Register `x` in the G2G hparams registry, then as base hparams.

    Args:
      x: passed unchanged to `Registries.g2g_hparams.register`.

    Returns:
      The result of `registry.register_hparams` applied to whatever the G2G
      registration returned.
    """
    # NOTE(review): same keyed-form caveat as `register_problem` -- confirm.
    return registry.register_hparams(
        registry.Registries.g2g_hparams.register(x))


# Overrides registry queries.
def list_g2g_problems():
    """Return the keys of the G2G problem registry in sorted order."""
    return sorted(registry.Registries.g2g_problems)


def list_g2g_hparams():
    """Return the keys of the G2G hparams registry in sorted order."""
    return sorted(registry.Registries.g2g_hparams)


# `list_problems` now reports only G2G problems; `list_all_problems` is bound
# to the base registry's `list_base_problems`.
list_problems = list_g2g_problems
list_all_problems = registry.list_base_problems

# The equivalent hparams overrides are left disabled:
#list_hparams = list_g2g_hparams
#list_all_hparams = registry.list_hparams
def testIteration(self):
    """Iterating over a registry yields its keys."""
    reg = registry.Registry("test_registry")
    for key, value in (("a", lambda: None), ("b", lambda: 4)):
        reg[key] = value

    # Sort before comparing so the check does not depend on iteration order.
    ordered_keys = sorted(reg)
    self.assertEqual(ordered_keys, ["a", "b"])
def testMembership(self):
    """Keys that were assigned are reported as members by the `in` operator."""
    r = registry.Registry("test_registry")
    r["a"] = lambda: None
    r["b"] = lambda: 4
    # assertIn performs the same `in` check as assertTrue("a" in r) but, on
    # failure, reports the member and the container instead of
    # "False is not true".
    self.assertIn("a", r)
    self.assertIn("b", r)
def testDefaultKeyFn(self):
    """A key-less registration is filed under the key from `default_key_fn`."""

    def upper_cased_result(x):
        # Derive the key by calling the value and upper-casing what it returns.
        return x().upper()

    reg = registry.Registry("test", default_key_fn=upper_cased_result)

    # `register()` without a key hands back a callable that takes the value.
    add_entry = reg.register()
    add_entry(lambda: "hello")

    stored = reg["HELLO"]
    self.assertEqual(stored(), "hello")
def testGetterSetter(self):
    """Values stored with item assignment come back unchanged via indexing."""
    reg = registry.Registry("test_registry")

    # (key, stored callable, value the callable is expected to return)
    cases = (
        ("hello", lambda: "world", "world"),
        ("a", lambda: "b", "b"),
    )

    # Store everything first, then read everything back.
    for key, producer, _ in cases:
        reg[key] = producer
    for key, _, expected in cases:
        self.assertEqual(reg[key](), expected)