예제 #1
0
 def test_invalid_writer(self):
     """Expect `keras.Callback(writer=None, ...)` to raise `TypeError`."""
     expected_message = "writer must be a `SummaryWriter` or `str`, not None"
     with six.assertRaisesRegex(self, TypeError, expected_message):
         keras.Callback(writer=None, hparams={})
예제 #2
0
 def test_duplicate_hparam_names_from_two_objects(self):
   """Expect `ValueError` when two `HParam` keys share the name "foo"."""
   # Two separately constructed HParam objects carrying the same name.
   first_foo = hp.HParam("foo")
   second_foo = hp.HParam("foo")
   hparams = {first_foo: 1, second_foo: 1}
   expected_message = "multiple values specified for hparam 'foo'"
   with six.assertRaisesRegex(self, ValueError, expected_message):
     keras.Callback(self.get_temp_dir(), hparams)
예제 #3
0
 def test_duplicate_hparam_names_across_object_and_string(self):
     """Expect `ValueError` when a `str` key and an `HParam` key both name "foo"."""
     # The plain-string key goes in first, then the HParam object key.
     hparams = {"foo": 1}
     hparams[hp.HParam("foo")] = 1
     expected_message = "multiple values specified for hparam 'foo'"
     with self.assertRaisesRegex(ValueError, expected_message):
         keras.Callback(self.get_temp_dir(), hparams)
예제 #4
0
 def _initialize_model(self, writer):
   """Populate `self.hparams`, `self.model`, and `self.callback`.

   Builds a two-layer sequential Keras model whose first layer width and
   optimizer are taken from the hparams dict, compiles it, and creates a
   `keras.Callback` from `writer` and the same hparams.

   Args:
     writer: Passed through as the first argument of `keras.Callback`.
   """
   dense_neurons = hp.HParam("dense_neurons", hp.IntInterval(4, 16))
   # Mix a string key and an HParam-object key in the same dict.
   hparams = {"optimizer": "adam"}
   hparams[dense_neurons] = 8
   self.hparams = hparams

   hidden_layer = tf.keras.layers.Dense(
       self.hparams[dense_neurons], input_shape=(1,)
   )
   output_layer = tf.keras.layers.Dense(1, activation="sigmoid")
   self.model = tf.keras.models.Sequential([hidden_layer, output_layer])
   self.model.compile(loss="mse", optimizer=self.hparams["optimizer"])

   self.callback = keras.Callback(writer, self.hparams)
예제 #5
0
 def test_invalid_trial_id(self):
   """Expect `TypeError` when `trial_id` is the integer 12."""
   expected_message = "`trial_id` should be a `str`, but got: 12"
   with six.assertRaisesRegex(self, TypeError, expected_message):
     keras.Callback(self.get_temp_dir(), {}, trial_id=12)