def testNamedRegistration(self):
    """Functions registered under an explicit name are retrievable by it.

    Both an hparams set and a ranged-hparams function are registered under
    the name ``"a"``; looking ``"a"`` up in each registry must return the very
    same function object that was decorated.
    """

    @registry.register_hparams("a")
    def my_hparams_set():
        pass

    @registry.register_ranged_hparams("a")
    def my_hparams_range(_):
        pass

    # assertIs reports both objects on failure, unlike assertTrue(a is b).
    self.assertIs(registry.hparams("a"), my_hparams_set)
    self.assertIs(registry.ranged_hparams("a"), my_hparams_range)
def testHParamSet(self):
    """Bare-decorator registrations are keyed by the function's own name.

    ``register_hparams`` and ``register_ranged_hparams`` are applied without
    an explicit name, so the lookups use ``"my_hparams_set"`` and
    ``"my_hparams_range"`` and must return the decorated function objects.
    """

    @registry.register_hparams
    def my_hparams_set():
        pass

    @registry.register_ranged_hparams
    def my_hparams_range(_):
        pass

    # assertIs reports both objects on failure, unlike assertTrue(a is b).
    self.assertIs(registry.hparams("my_hparams_set"), my_hparams_set)
    self.assertIs(
        registry.ranged_hparams("my_hparams_range"), my_hparams_range)
def testNamedRegistration(self):
    """Explicitly named registrations resolve under that name.

    In this revision ``registry.hparams("a")`` is expected to equal the value
    produced by calling the registered set function (``7``), while
    ``registry.ranged_hparams("a")`` must return the decorated function
    object itself.
    """

    @registry.register_hparams("a")
    def my_hparams_set():
        return 7

    @registry.register_ranged_hparams("a")
    def my_hparams_range(_):
        pass

    self.assertEqual(registry.hparams("a"), my_hparams_set())
    # assertIs reports both objects on failure, unlike assertTrue(a is b).
    self.assertIs(registry.ranged_hparams("a"), my_hparams_range)
def testHParamSet(self):
    """Bare-decorator registrations resolve under the function's own name.

    ``registry.hparams("my_hparams_set")`` is expected to equal the value
    returned by calling the registered function (``3``).
    ``registry.ranged_hparams("my_hparams_range")`` must return the decorated
    function object.
    """

    @registry.register_hparams
    def my_hparams_set():
        return 3

    @registry.register_ranged_hparams
    def my_hparams_range(_):
        pass

    self.assertEqual(registry.hparams("my_hparams_set"), my_hparams_set())
    # assertIs reports both objects on failure, unlike assertTrue(a is b).
    self.assertIs(
        registry.ranged_hparams("my_hparams_range"), my_hparams_range)
def autotune_paramspecs(hparams_range):
    """Return parameter specs for the registered ranged-hparams function.

    Args:
      hparams_range: name under which the ranged-hparams function was
        registered with ``registry``.

    Returns:
      Whatever ``RangedHParams.to_parameter_specs`` produces, with every
      parameter name prefixed by ``'hp_'``.
    """
    ranges = common_hparams.RangedHParams()
    fill_ranges = registry.ranged_hparams(hparams_range)
    # The looked-up function is given the fresh container; presumably it
    # populates it in place, since its return value is discarded.
    fill_ranges(ranges)
    specs = ranges.to_parameter_specs(name_prefix='hp_')
    return specs
def testUnknownHparams(self):
    """Looking up an unregistered name raises LookupError.

    The error message must match ``"never registered"`` for both the hparams
    registry and the ranged-hparams registry.
    """
    # assertRaisesRegexp is a deprecated alias that was removed in Python 3.12.
    # Prefer assertRaisesRegex, and fall back to the old spelling only where
    # the new one does not exist (e.g. Python 2's unittest).
    assert_raises_regex = getattr(self, "assertRaisesRegex", None)
    if assert_raises_regex is None:
        assert_raises_regex = self.assertRaisesRegexp
    with assert_raises_regex(LookupError, "never registered"):
        registry.hparams("not_registered")
    with assert_raises_regex(LookupError, "never registered"):
        registry.ranged_hparams("not_registered")
def autotune_paramspecs(hparams_range):
    """Convert a registered ranged-hparams function into parameter specs.

    Args:
      hparams_range: registry name of the ranged-hparams function to apply.

    Returns:
      The result of ``RangedHParams.to_parameter_specs`` on a container that
      has been passed through the registered function, with parameter names
      prefixed by ``"hp_"``.
    """
    lookup = registry.ranged_hparams
    container = common_hparams.RangedHParams()
    # Return value of the registered function is ignored; it presumably
    # mutates `container` in place — confirm against the registry contract.
    lookup(hparams_range)(container)
    return container.to_parameter_specs(name_prefix="hp_")