def _write_hparams_config(log_dir, searchspace):
    """Write the TensorBoard HParams configuration for ``searchspace``.

    Declares the hyperparameters built by ``_create_hparams_config`` and four
    metrics (accuracy and loss, each for the ``validation`` and ``train``
    groups) by calling ``hp.hparams_config`` under a file writer rooted at
    ``log_dir``.

    Args:
        log_dir: Directory the summary file writer writes event files into.
        searchspace: Search-space description forwarded unchanged to
            ``_create_hparams_config`` (its expected type is defined by that
            helper).
    """
    hparams = _create_hparams_config(searchspace)
    # Tag names and groups are kept exactly as before. Presumably they must
    # match the scalar tags/run subdirectories the training loop logs to;
    # confirm before renaming.
    metrics = [
        hp.Metric(
            "epoch_acc",
            group="validation",
            display_name="accuracy (val.)",
        ),
        hp.Metric(
            "epoch_loss",
            group="validation",
            display_name="loss (val.)",
        ),
        hp.Metric(
            "epoch_acc",
            group="train",
            display_name="accuracy (train)",
        ),
        hp.Metric(
            "epoch_loss",
            group="train",
            display_name="loss (train)",
        ),
    ]
    writer = tf.summary.create_file_writer(log_dir)
    try:
        with writer.as_default():
            hp.hparams_config(hparams=hparams, metrics=metrics)
    finally:
        # Release the event-file handle deterministically instead of relying
        # on garbage collection of an anonymous writer object.
        writer.close()
def test_eager_no_default_writer(self):
    """hparams_config reports a falsy result when no default writer is set."""
    outcome = hp.hparams_config(
        time_created_secs=self.time_created_secs,
        metrics=self.metrics,
        hparams=self.hparams,
    )
    # Nothing installs a default summary writer here, so the call must not
    # report success.
    self.assertFalse(outcome)
def test_eager(self):
    """Eagerly writing the config under a default writer succeeds.

    After the writer context exits, the log directory is validated with
    ``_check_logdir``.
    """
    with tf.compat.v2.summary.create_file_writer(
        self.logdir
    ).as_default():
        wrote = hp.hparams_config(
            hparams=self.hparams,
            metrics=self.metrics,
            time_created_secs=self.time_created_secs,
        )
        self.assertTrue(wrote)
    self._check_logdir(self.logdir)
def test_graph_mode(self):
    """In TF1 graph mode, the config op runs successfully and is flushed.

    The writer is initialized and explicitly flushed through the session,
    then the log directory is validated with ``_check_logdir``.
    """
    with tf.compat.v1.Graph().as_default():
        with tf.compat.v1.Session() as session:
            with tf.compat.v2.summary.create_file_writer(
                self.logdir
            ).as_default() as writer:
                session.run(writer.init())
                config_op = hp.hparams_config(
                    hparams=self.hparams,
                    metrics=self.metrics,
                    time_created_secs=self.time_created_secs,
                )
                self.assertTrue(session.run(config_op))
                session.run(writer.flush())
    self._check_logdir(self.logdir)