예제 #1
0
    def testNamedRegistration(self):
        """Functions registered under an explicit name are looked up by it.

        Both an hparams set and a ranged-hparams function are registered
        under the name "a"; each lookup is expected to return the very
        function object that was decorated.
        """
        @registry.register_hparams("a")
        def my_hparams_set():
            pass

        @registry.register_ranged_hparams("a")
        def my_hparams_range(_):
            pass

        # assertIs shows both objects on failure; assertTrue(x is y) would
        # only report "False is not true".
        self.assertIs(registry.hparams("a"), my_hparams_set)
        self.assertIs(registry.ranged_hparams("a"), my_hparams_range)
예제 #2
0
    def testHParamSet(self):
        """Functions registered without a name are looked up by function name.

        The decorators are applied bare, and the lookups use the decorated
        functions' own names ("my_hparams_set" / "my_hparams_range"); each
        lookup is expected to return the very function object decorated.
        """
        @registry.register_hparams
        def my_hparams_set():
            pass

        @registry.register_ranged_hparams
        def my_hparams_range(_):
            pass

        # assertIs shows both objects on failure; assertTrue(x is y) would
        # only report "False is not true".
        self.assertIs(registry.hparams("my_hparams_set"), my_hparams_set)
        self.assertIs(
            registry.ranged_hparams("my_hparams_range"), my_hparams_range)
예제 #3
0
  def testNamedRegistration(self):
    """Functions registered under an explicit name are looked up by it.

    The hparams lookup for "a" is expected to equal the value returned by
    the registered function, while the ranged-hparams lookup for "a" is
    expected to return the registered function object itself.
    """

    @registry.register_hparams("a")
    def my_hparams_set():
      return 7

    @registry.register_ranged_hparams("a")
    def my_hparams_range(_):
      pass

    self.assertEqual(registry.hparams("a"), my_hparams_set())
    # assertIs shows both objects on failure; assertTrue(x is y) would only
    # report "False is not true".
    self.assertIs(registry.ranged_hparams("a"), my_hparams_range)
예제 #4
0
  def testHParamSet(self):
    """Functions registered without a name are looked up by function name.

    The hparams lookup for "my_hparams_set" is expected to equal the value
    returned by the registered function, while the ranged-hparams lookup
    for "my_hparams_range" is expected to return the function object itself.
    """

    @registry.register_hparams
    def my_hparams_set():
      return 3

    @registry.register_ranged_hparams
    def my_hparams_range(_):
      pass

    self.assertEqual(registry.hparams("my_hparams_set"), my_hparams_set())
    # assertIs shows both objects on failure; assertTrue(x is y) would only
    # report "False is not true".
    self.assertIs(
        registry.ranged_hparams("my_hparams_range"), my_hparams_range)
예제 #5
0
def autotune_paramspecs(hparams_range):
  """Return parameter specs for a registered ranged-hparams function.

  Args:
    hparams_range: name passed to `registry.ranged_hparams` to look up the
      ranged-hparams function.

  Returns:
    The result of `to_parameter_specs` on the `RangedHParams` object, with
    every name prefixed by 'hp_'.
  """
  ranged = common_hparams.RangedHParams()
  # The registered function receives the fresh RangedHParams object;
  # presumably it fills in the ranges in place (its return value is unused).
  fill_ranges = registry.ranged_hparams(hparams_range)
  fill_ranges(ranged)
  specs = ranged.to_parameter_specs(name_prefix='hp_')
  return specs
예제 #6
0
 def testUnknownHparams(self):
     """Looking up an unregistered name raises LookupError in both registries.

     The error message is expected to match "never registered".
     """
     # Same expectation for the hparams and the ranged-hparams lookups.
     for lookup in (registry.hparams, registry.ranged_hparams):
         with self.assertRaisesRegexp(LookupError, "never registered"):
             lookup("not_registered")
예제 #7
0
 def testUnknownHparams(self):
   """Looking up an unregistered name raises LookupError in both registries.

   The error message is expected to match "never registered".
   """
   unknown_name = "not_registered"
   expected_message = "never registered"
   # Callable form: the assertion itself invokes the lookup with the name.
   self.assertRaisesRegexp(
       LookupError, expected_message, registry.hparams, unknown_name)
   self.assertRaisesRegexp(
       LookupError, expected_message, registry.ranged_hparams, unknown_name)
예제 #8
0
def autotune_paramspecs(hparams_range):
  """Return parameter specs for a registered ranged-hparams function.

  Args:
    hparams_range: name passed to `registry.ranged_hparams` to look up the
      ranged-hparams function.

  Returns:
    The result of `to_parameter_specs` on the `RangedHParams` object, with
    every name prefixed by "hp_".
  """
  ranged = common_hparams.RangedHParams()
  # The registered function receives the fresh RangedHParams object;
  # presumably it fills in the ranges in place (its return value is unused).
  fill_ranges = registry.ranged_hparams(hparams_range)
  fill_ranges(ranged)
  specs = ranged.to_parameter_specs(name_prefix="hp_")
  return specs