예제 #1
0
 def test_eager(self):
     """Eager-mode `hp.hparams` reports success under an active default writer.

     A v2 file writer for `self.logdir` is installed as the default while the
     hparams summary is written; the log directory is then validated with
     `_check_logdir`.
     """
     logdir = self.logdir
     with tf.compat.v2.summary.create_file_writer(logdir).as_default():
         written = hp.hparams(
             self.hparams,
             trial_id=self.trial_id,
             start_time_secs=self.start_time_secs,
         )
         # A truthy return value signals that the summary was recorded.
         self.assertTrue(written)
     self._check_logdir(logdir)
예제 #2
0
 def test_graph_mode(self):
     """Graph-mode `hp.hparams` builds an op that evaluates to a truthy value.

     Runs in a fresh v1 graph and session with a v2 file writer installed as
     the default. The writer is initialized before the summary op runs and
     flushed afterwards. The log directory is then validated with
     `_check_logdir`.
     """
     with tf.compat.v1.Graph().as_default():
         with tf.compat.v1.Session() as session:
             with tf.compat.v2.summary.create_file_writer(
                     self.logdir).as_default() as file_writer:
                 # In graph mode the writer must be initialized explicitly.
                 session.run(file_writer.init())
                 summary_op = hp.hparams(
                     self.hparams, start_time_secs=self.start_time_secs
                 )
                 self.assertTrue(session.run(summary_op))
                 session.run(file_writer.flush())
     self._check_logdir(self.logdir)
예제 #3
0
 def test_eager_no_default_writer(self):
     """Without a default summary writer, `hp.hparams` returns a falsy value."""
     self.assertFalse(
         hp.hparams(self.hparams, start_time_secs=self.start_time_secs)
     )
예제 #4
0
 def on_train_begin(self, logs=None):
     """Write the hparams summary when training begins.

     The summary for `self._hparams`, tagged with `self._trial_id`, is written
     while the writer returned by `self._get_writer()` is the default.

     Args:
       logs: Ignored.
     """
     del logs  # unused
     writer = self._get_writer()
     with writer.as_default():
         summary_v2.hparams(self._hparams, trial_id=self._trial_id)
예제 #5
0
def _write_hparams(hparams):
    """Write an hparams summary using the module-level `_writer`.

    Args:
      hparams: Passed unchanged to `hp.hparams` (presumably a mapping of
        hyperparameters; verify against callers).
    """
    # `_writer` is only read here, never rebound, so no `global` statement is
    # required for the lookup to resolve to the module-level name.
    with _writer.as_default():
        hp.hparams(hparams)