예제 #1
0
def _write_hparams_config(log_dir, searchspace):
    """Write the HParams experiment configuration summary into ``log_dir``.

    The hyperparameter definitions come from
    ``_create_hparams_config(searchspace)``. Four metrics are declared:
    ``epoch_acc`` and ``epoch_loss``, each for the ``validation`` and the
    ``train`` group.

    Args:
        log_dir: Location handed to ``tf.summary.create_file_writer``.
        searchspace: Forwarded unchanged to ``_create_hparams_config``.
    """
    hparam_defs = _create_hparams_config(searchspace)

    # (tag, group, display name) -- the order here is the order in which the
    # metrics are passed to hp.hparams_config.
    metric_specs = (
        ("epoch_acc", "validation", "accuracy (val.)"),
        ("epoch_loss", "validation", "loss (val.)"),
        ("epoch_acc", "train", "accuracy (train)"),
        ("epoch_loss", "train", "loss (train)"),
    )
    metric_defs = [
        hp.Metric(tag, group=group, display_name=label)
        for tag, group, label in metric_specs
    ]

    writer = tf.summary.create_file_writer(log_dir)
    with writer.as_default():
        hp.hparams_config(hparams=hparam_defs, metrics=metric_defs)
예제 #2
0
 def test_eager_no_default_writer(self):
     """hparams_config returns a falsy value when no default writer is set."""
     config_kwargs = {
         "hparams": self.hparams,
         "metrics": self.metrics,
         "time_created_secs": self.time_created_secs,
     }
     # Deliberately called outside any `as_default()` writer context.
     wrote_summary = hp.hparams_config(**config_kwargs)
     self.assertFalse(wrote_summary)  # no default writer
예제 #3
0
 def test_eager(self):
     """hparams_config returns a truthy value under a default file writer.

     After the writer context is left, the log directory is verified via
     ``self._check_logdir``.
     """
     writer = tf.compat.v2.summary.create_file_writer(self.logdir)
     with writer.as_default():
         self.assertTrue(
             hp.hparams_config(
                 hparams=self.hparams,
                 metrics=self.metrics,
                 time_created_secs=self.time_created_secs,
             )
         )
     self._check_logdir(self.logdir)
예제 #4
0
 def test_graph_mode(self):
     """hparams_config yields a runnable, truthy summary op in a v1 session.

     The contexts are entered in this order: a fresh v1 graph as default, a
     v1 session, then a v2 file writer for ``self.logdir`` as the default
     writer. The writer is initialised, the summary is run, and the writer
     is flushed before ``self._check_logdir`` inspects the log directory.
     """
     graph = tf.compat.v1.Graph()
     with graph.as_default():
         with tf.compat.v1.Session() as sess:
             file_writer_ctx = tf.compat.v2.summary.create_file_writer(
                 self.logdir
             ).as_default()
             with file_writer_ctx as file_writer:
                 # The writer's init op is run through the session before
                 # the summary is evaluated.
                 sess.run(file_writer.init())
                 summary_op = hp.hparams_config(
                     hparams=self.hparams,
                     metrics=self.metrics,
                     time_created_secs=self.time_created_secs,
                 )
                 summary_result = sess.run(summary_op)
                 self.assertTrue(summary_result)
                 sess.run(file_writer.flush())
     self._check_logdir(self.logdir)