def test_invalid_writer(self):
    """Constructing the callback with `writer=None` raises `TypeError`."""
    expected_message = "writer must be a `SummaryWriter` or `str`, not None"
    with six.assertRaisesRegex(self, TypeError, expected_message):
        keras.Callback(writer=None, hparams={})
def test_duplicate_hparam_names_from_two_objects(self):
    """Two separate `HParam` objects named "foo" are rejected as duplicates."""
    # Two separately constructed objects that share the same name.
    first_foo = hp.HParam("foo")
    second_foo = hp.HParam("foo")
    hparams = {first_foo: 1, second_foo: 1}
    expected_message = "multiple values specified for hparam 'foo'"
    with six.assertRaisesRegex(self, ValueError, expected_message):
        keras.Callback(self.get_temp_dir(), hparams)
def test_duplicate_hparam_names_across_object_and_string(self):
    """A string key and an `HParam` object both named "foo" are rejected.

    The hparams dict mixes a plain string key with an `HParam` object of
    the same name; constructing the callback is expected to raise
    `ValueError` reporting multiple values for that hparam.
    """
    hparams = {
        "foo": 1,
        hp.HParam("foo"): 1,
    }
    # Go through the six helper, as the sibling tests do, rather than
    # calling `self.assertRaisesRegex` directly: six resolves the method
    # name that the running Python version's TestCase actually provides.
    with six.assertRaisesRegex(
        self, ValueError, "multiple values specified for hparam 'foo'"
    ):
        keras.Callback(self.get_temp_dir(), hparams)
def _initialize_model(self, writer):
    """Set up `self.hparams`, `self.model`, and `self.callback`.

    Builds a two-layer `Sequential` model whose first layer width comes
    from the "dense_neurons" hparam, compiles it with the optimizer named
    in the hparams, and creates a `keras.Callback` for those hparams.

    Args:
      writer: Forwarded unchanged as the first argument of
        `keras.Callback`.
    """
    dense_neurons_hparam = hp.HParam("dense_neurons", hp.IntInterval(4, 16))
    self.hparams = {
        "optimizer": "adam",
        dense_neurons_hparam: 8,
    }

    # Hidden layer is sized by the hparam value; output is a single unit.
    hidden_layer = tf.keras.layers.Dense(
        self.hparams[dense_neurons_hparam], input_shape=(1,)
    )
    output_layer = tf.keras.layers.Dense(1, activation="sigmoid")
    self.model = tf.keras.models.Sequential([hidden_layer, output_layer])

    optimizer_name = self.hparams["optimizer"]
    self.model.compile(loss="mse", optimizer=optimizer_name)

    self.callback = keras.Callback(writer, self.hparams)
def test_invalid_trial_id(self):
    """A non-string `trial_id` (here the int 12) raises `TypeError`."""
    expected_message = "`trial_id` should be a `str`, but got: 12"
    with six.assertRaisesRegex(self, TypeError, expected_message):
        keras.Callback(self.get_temp_dir(), {}, trial_id=12)