def testNoneHparams(self):
    """Fetching a registered hparams set whose function returns None fails.

    The lookup is expected to raise TypeError with a message matching
    "is None".
    """

    @registry.register_hparams
    def hp():
        # Intentionally produces no HParams object.
        return None

    expected_message = "is None"
    # NOTE: assertRaisesRegexp is a deprecated alias (removed in Python 3.12);
    # switch to assertRaisesRegex once Python 2 compatibility is not needed.
    with self.assertRaisesRegexp(TypeError, expected_message):
        registry.hparams("hp")
def testNamedRegistration(self):
    """Registering under an explicit name makes both lookups use that name.

    The same name "a" is used for an hparams set and for a ranged-hparams
    function. Each is then retrieved from its own registry accessor.
    """

    @registry.register_hparams("a")
    def my_hparams_set():
        return 7

    @registry.register_ranged_hparams("a")
    def my_hparams_range(_):
        pass

    # registry.hparams returns the set's value, so compare with a fresh call.
    self.assertEqual(registry.hparams("a"), my_hparams_set())
    # assertIs reports both objects on failure, unlike assertTrue(x is y).
    self.assertIs(registry.ranged_hparams("a"), my_hparams_range)
def testHParamSet(self):
    """Registration without an explicit name uses the function's own name."""

    @registry.register_hparams
    def my_hparams_set():
        return 3

    @registry.register_ranged_hparams
    def my_hparams_range(_):
        pass

    # registry.hparams returns the set's value, so compare with a fresh call.
    self.assertEqual(registry.hparams("my_hparams_set"), my_hparams_set())
    # assertIs reports both objects on failure, unlike assertTrue(x is y).
    self.assertIs(registry.ranged_hparams("my_hparams_range"), my_hparams_range)
def create_hparams(hparams_set,
                   hparams_overrides_str="",
                   data_dir=None,
                   problem_name=None,
                   hparams_path=None):
    """Create HParams with data_dir and problem hparams, if kwargs provided.

    Steps, in order:
      1. Start from the registered set named `hparams_set`.
      2. If `hparams_path` is given and exists, pass it and the current
         hparams to `create_hparams_from_json`.
      3. If `data_dir` is truthy, add it as the `data_dir` hparam.
      4. If `hparams_overrides_str` is non-empty, log it and apply it via
         `parse`.
      5. If `problem_name` is truthy, attach that problem's hparams.

    Args:
      hparams_set: name of a registered hparams set.
      hparams_overrides_str: override string handed to `parse`; ignored if
        empty.
      data_dir: optional value stored as the `data_dir` hparam.
      problem_name: optional problem name for `add_problem_hparams`.
      hparams_path: optional JSON path; silently skipped when it does not
        exist.

    Returns:
      The resulting hparams object.
    """
    merged = registry.hparams(hparams_set)

    # A missing JSON file is not an error: it is simply not used.
    json_present = hparams_path and tf.gfile.Exists(hparams_path)
    if json_present:
        merged = create_hparams_from_json(hparams_path, merged)

    if data_dir:
        merged.add_hparam("data_dir", data_dir)

    if hparams_overrides_str:
        tf.logging.info("Overriding hparams in %s with %s",
                        hparams_set, hparams_overrides_str)
        # parse's return value replaces the working object.
        merged = merged.parse(hparams_overrides_str)

    if problem_name:
        add_problem_hparams(merged, problem_name)

    return merged
def testUnknownHparams(self):
    """Both hparams registries reject a name that was never registered.

    Each lookup is expected to raise KeyError with a message matching
    "never registered".
    """
    # NOTE: assertRaisesRegexp is a deprecated alias (removed in Python 3.12);
    # switch to assertRaisesRegex once Python 2 compatibility is not needed.
    for lookup in (registry.hparams, registry.ranged_hparams):
        with self.assertRaisesRegexp(KeyError, "never registered"):
            lookup("not_registered")