def test_convert_hyperparams_to_hparams_fixed_bool(self):
    """A ``Fixed`` hyperparameter holding ``True`` maps to ``Discrete([True])``."""
    space = hp_module.HyperParameters()
    space.Fixed("condition", True)
    converted = utils.convert_hyperparams_to_hparams(space)
    # Single-entry dicts, so comparing the dict reprs is order-safe.
    domain = hparams_api.Discrete([True])
    want = {hparams_api.HParam("condition", domain): True}
    self.assertEqual(repr(converted), repr(want))
def test_convert_hyperparams_to_hparams_fixed(self, name, value):
    """A ``Fixed`` hyperparameter becomes a single-value ``Discrete`` HParam.

    ``name`` and ``value`` come from the test parameterization (presumably a
    decorator outside this view). The converted entry's value is ``value``.
    """
    space = hp_module.HyperParameters()
    space.Fixed(name, value)
    converted = utils.convert_hyperparams_to_hparams(space)
    key = hparams_api.HParam(name, hparams_api.Discrete([value]))
    want = {key: value}
    self.assertEqual(repr(converted), repr(want))
def test_convert_hyperparams_to_hparams_boolean(self):
    """A ``Boolean`` hyperparameter maps to ``Discrete([True, False])``.

    The expected value of the converted entry is ``False``.
    """
    space = hp_module.HyperParameters()
    space.Boolean("has_beta")
    converted = utils.convert_hyperparams_to_hparams(space)
    domain = hparams_api.Discrete([True, False])
    key = hparams_api.HParam("has_beta", domain)
    want = {key: False}
    self.assertEqual(repr(converted), repr(want))
def test_convert_hyperparams_to_hparams_choice(self):
    """A ``Choice`` hyperparameter maps to a ``Discrete`` domain of its options.

    The converted entry's value is expected to be the first option.
    """
    options = (1e-4, 1e-3, 1e-2)
    space = hp_module.HyperParameters()
    # Fresh lists for each consumer so neither can alias the other's input.
    space.Choice("learning_rate", list(options))
    converted = utils.convert_hyperparams_to_hparams(space)
    key = hparams_api.HParam("learning_rate", hparams_api.Discrete(list(options)))
    want = {key: options[0]}
    self.assertEqual(repr(converted), repr(want))
def test_convert_hyperparams_to_hparams_float(
    self, name, min_value, max_value, step, expected_domain, expected_value
):
    """A ``Float`` hyperparameter converts to ``HParam(name, expected_domain)``.

    All arguments come from the test parameterization (presumably a decorator
    outside this view). ``expected_value`` is the value the converted entry
    should carry.
    """
    space = hp_module.HyperParameters()
    space.Float(name, min_value=min_value, max_value=max_value, step=step)
    converted = utils.convert_hyperparams_to_hparams(space)
    key = hparams_api.HParam(name, expected_domain)
    want = {key: expected_value}
    self.assertEqual(repr(converted), repr(want))
def test_convert_hyperparams_to_hparams_multi_float(self):
    """Each of several ``Float`` hyperparameters becomes a ``RealInterval`` HParam.

    The order of entries in the converted dict is not asserted, so the
    comparison is order-insensitive. Each (HParam, value) pair is compared via
    ``repr`` so that the HParam names and domains are checked, not only the
    values.
    """
    hps = hp_module.HyperParameters()
    hps.Float("theta", min_value=0.0, max_value=1.57)
    hps.Float("r", min_value=0.0, max_value=1.0)
    hparams = utils.convert_hyperparams_to_hparams(hps)
    expected_hparams = {
        hparams_api.HParam("r", hparams_api.RealInterval(0.0, 1.0)): 0.0,
        hparams_api.HParam("theta", hparams_api.RealInterval(0.0, 1.57)): 0.0,
    }
    # Include the keys in the compared reprs. Comparing only the values
    # (all 0.0 here) would pass even with wrong names or interval bounds.
    hparams_repr_list = [repr(item) for item in hparams.items()]
    expected_hparams_repr_list = [
        repr(item) for item in expected_hparams.items()
    ]
    self.assertCountEqual(hparams_repr_list, expected_hparams_repr_list)