Beispiel #1
0
 def test_convert_hyperparams_to_hparams_fixed_bool(self):
     """Checks Fixed("condition", True) maps to a one-element Discrete HParam.

     The converted dict must hold exactly one entry, keyed by an HParam named
     "condition" over Discrete([True]), whose value is True.
     """
     space = hp_module.HyperParameters()
     space.Fixed("condition", True)
     converted = utils.convert_hyperparams_to_hparams(space)
     condition_hparam = hparams_api.HParam(
         "condition", hparams_api.Discrete([True]))
     expected = {condition_hparam: True}
     # Compared via repr() so the HParam name and domain are both checked.
     self.assertEqual(repr(converted), repr(expected))
Beispiel #2
0
 def test_convert_hyperparams_to_hparams_fixed(self, name, value):
     """Checks a Fixed hyperparameter maps to a one-element Discrete HParam.

     `name` and `value` are presumably supplied by a parameterization
     decorator outside this view.

     Args:
       name: Name given to the Fixed hyperparameter.
       value: The fixed value. It is expected both as the sole element of the
         Discrete domain and as the value mapped to the HParam.
     """
     space = hp_module.HyperParameters()
     space.Fixed(name, value)
     converted = utils.convert_hyperparams_to_hparams(space)
     single_value_domain = hparams_api.Discrete([value])
     expected = {hparams_api.HParam(name, single_value_domain): value}
     self.assertEqual(repr(converted), repr(expected))
Beispiel #3
0
 def test_convert_hyperparams_to_hparams_boolean(self):
     """Checks a Boolean hyperparameter maps to Discrete([True, False]).

     No default is passed to `Boolean`. The converted dict is expected to map
     the "has_beta" HParam to False.
     """
     space = hp_module.HyperParameters()
     space.Boolean("has_beta")
     converted = utils.convert_hyperparams_to_hparams(space)
     both_truth_values = hparams_api.Discrete([True, False])
     expected = {hparams_api.HParam("has_beta", both_truth_values): False}
     self.assertEqual(repr(converted), repr(expected))
Beispiel #4
0
 def test_convert_hyperparams_to_hparams_choice(self):
     """Checks a Choice hyperparameter maps to a Discrete domain of its values.

     The converted dict is expected to map the "learning_rate" HParam to
     1e-4, the first of the listed choices.
     """
     space = hp_module.HyperParameters()
     space.Choice("learning_rate", [1e-4, 1e-3, 1e-2])
     converted = utils.convert_hyperparams_to_hparams(space)
     choice_domain = hparams_api.Discrete([1e-4, 1e-3, 1e-2])
     expected = {hparams_api.HParam("learning_rate", choice_domain): 1e-4}
     self.assertEqual(repr(converted), repr(expected))
Beispiel #5
0
 def test_convert_hyperparams_to_hparams_float(
         self, name, min_value, max_value, step, expected_domain,
         expected_value):
     """Checks a Float hyperparameter converts to the expected HParam/value.

     The arguments are presumably supplied by a parameterization decorator
     outside this view.

     Args:
       name: Name of the Float hyperparameter.
       min_value: Lower bound passed to `Float`.
       max_value: Upper bound passed to `Float`.
       step: Step passed to `Float`.
       expected_domain: Domain the resulting HParam is expected to carry.
       expected_value: Value the converted dict is expected to map it to.
     """
     space = hp_module.HyperParameters()
     space.Float(name, min_value=min_value, max_value=max_value, step=step)
     converted = utils.convert_hyperparams_to_hparams(space)
     expected_key = hparams_api.HParam(name, expected_domain)
     expected = {expected_key: expected_value}
     self.assertEqual(repr(converted), repr(expected))
Beispiel #6
0
 def test_convert_hyperparams_to_hparams_multi_float(self):
     """Checks that two Float hyperparameters are both converted correctly.

     Each HParam key (name and RealInterval domain) is compared together with
     its value. The comparison is order-insensitive because the converted
     dict's entry order need not match the order of `expected_hparams`.
     """
     hps = hp_module.HyperParameters()
     hps.Float("theta", min_value=0.0, max_value=1.57)
     hps.Float("r", min_value=0.0, max_value=1.0)
     hparams = utils.convert_hyperparams_to_hparams(hps)
     expected_hparams = {
         hparams_api.HParam("r", hparams_api.RealInterval(0.0, 1.0)): 0.0,
         hparams_api.HParam("theta",
                            hparams_api.RealInterval(0.0, 1.57)): 0.0,
     }
     # Compare (HParam, value) pairs, not values alone. Every expected value
     # here is 0.0, so a values-only comparison would pass even if the
     # converter produced wrong names or domains. repr() matches how the
     # single-hyperparameter tests compare their results.
     hparams_repr_list = [repr(item) for item in hparams.items()]
     expected_hparams_repr_list = [
         repr(item) for item in expected_hparams.items()
     ]
     self.assertCountEqual(hparams_repr_list, expected_hparams_repr_list)