示例#1
0
    def testNoneHparams(self):
        """Looking up a set whose function returns nothing raises TypeError."""

        def hp():
            pass

        # Same effect as decorating `hp` with @registry.register_hparams:
        # it is registered under its own function name, "hp".
        hp = registry.register_hparams(hp)

        expected_message = "is None"
        with self.assertRaisesRegexp(TypeError, expected_message):
            registry.hparams("hp")
示例#2
0
    def testNamedRegistration(self):
        """Sets and ranges registered under an explicit name resolve by it."""

        @registry.register_hparams("a")
        def my_hparams_set():
            return 7

        @registry.register_ranged_hparams("a")
        def my_hparams_range(_):
            pass

        # Both are looked up by the explicit name "a", not the function names.
        # An hparams lookup yields the set function's return value ...
        self.assertEqual(registry.hparams("a"), my_hparams_set())
        # ... while a ranged lookup yields the registered function itself.
        # assertIs reports both objects on failure, unlike assertTrue(x is y).
        self.assertIs(registry.ranged_hparams("a"), my_hparams_range)
示例#3
0
    def testHParamSet(self):
        """Sets and ranges registered without a name resolve by function name."""

        @registry.register_hparams
        def my_hparams_set():
            return 3

        @registry.register_ranged_hparams
        def my_hparams_range(_):
            pass

        # An hparams lookup yields the set function's return value ...
        self.assertEqual(registry.hparams("my_hparams_set"), my_hparams_set())
        # ... while a ranged lookup yields the registered function itself.
        # assertIs reports both objects on failure, unlike assertTrue(x is y).
        self.assertIs(
            registry.ranged_hparams("my_hparams_range"), my_hparams_range)
示例#4
0
def create_hparams(hparams_set,
                   hparams_overrides_str="",
                   data_dir=None,
                   problem_name=None,
                   hparams_path=None):
    """Create HParams with data_dir and problem hparams, if kwargs provided.

    Steps are applied in this order: registered set -> JSON file ->
    data_dir -> override string -> problem hparams.

    Args:
      hparams_set: name of a registered hparams set, resolved with
        `registry.hparams`.
      hparams_overrides_str: optional override string passed to
        `hparams.parse`; ignored when empty.
      data_dir: optional value stored as the `data_dir` hparam. It replaces
        any `data_dir` value already present.
      problem_name: optional problem name; when given,
        `add_problem_hparams` is applied to the result.
      hparams_path: optional path to a JSON file. If the file exists, it is
        merged in with `create_hparams_from_json`; a missing file is skipped
        without error.

    Returns:
      The resulting hparams object.
    """
    hparams = registry.hparams(hparams_set)
    if hparams_path and tf.gfile.Exists(hparams_path):
        hparams = create_hparams_from_json(hparams_path, hparams)
    if data_dir:
        # HParams.add_hparam raises ValueError when the name already holds a
        # non-None value. Hparams merged from JSON may already carry
        # `data_dir` (presumably saved by an earlier run -- verify), so in
        # that case update the existing value instead of re-adding it.
        if getattr(hparams, "data_dir", None) is not None:
            hparams.set_hparam("data_dir", data_dir)
        else:
            hparams.add_hparam("data_dir", data_dir)
    if hparams_overrides_str:
        tf.logging.info("Overriding hparams in %s with %s", hparams_set,
                        hparams_overrides_str)
        hparams = hparams.parse(hparams_overrides_str)
    if problem_name:
        add_problem_hparams(hparams, problem_name)
    return hparams
示例#5
0
 def testUnknownHparams(self):
     """Both registries raise KeyError for a name that was never registered."""
     lookups = (registry.hparams, registry.ranged_hparams)
     for lookup in lookups:
         with self.assertRaisesRegexp(KeyError, "never registered"):
             lookup("not_registered")