Ejemplo n.º 1
0
 def test_finalize_stats_summaries(self):
   """Summaries report per-batch losses, scalar metrics and histograms.

   Loss values (summed over elements, plus their sum `loss_total`) are
   expected to be divided by the value fed to `batch_size_placeholder`;
   the scalar metric `baz` is expected unchanged and the vector metric
   `qux` is expected as a histogram summary.
   """
   the_plan = plan.Plan(None)
   the_plan.save_summaries_secs = 42
   the_plan.losses['foo'] = tf.constant([1.0])
   the_plan.losses['bar'] = tf.constant([2.0, 3.0])
   the_plan.metrics['baz'] = tf.constant(4)
   the_plan.metrics['qux'] = tf.constant([5.0, 6.0])
   the_plan.finalize_stats()
   # Unnormalized loss totals; the summary value is total / batch size.
   loss_totals = {'foo': 1.0, 'bar': 5.0, 'loss_total': 6.0}
   with self.test_session():
     placeholder = the_plan.batch_size_placeholder
     self.assertEqual(6, the_plan.loss_total.eval({placeholder: 1}))
     # Reference histogram for the non-scalar metric, built from the same
     # values so it can be compared proto-for-proto.
     reference = tf.Summary()
     reference.ParseFromString(tf.summary.histogram('qux', [5, 6]).eval())
     qux_histogram = reference.value[0].histo
     for size in (1, 2):
       actual = tf.Summary()
       actual.ParseFromString(the_plan.summaries.eval({placeholder: size}))
       expected = [tf.Summary.Value(tag=tag, simple_value=total / size)
                   for tag, total in loss_totals.items()]
       expected.append(tf.Summary.Value(tag='baz', simple_value=4))
       expected.append(tf.Summary.Value(tag='qux', histo=qux_histogram))
       six.assertCountEqual(self, expected, actual.value)
Ejemplo n.º 2
0
 def test_finalize_stats_invalid_loss_dtype(self):
   """finalize_stats() raises TypeError for a non-floating-point loss."""
   the_plan = plan.Plan(None)
   # An integer constant; the expected message names its dtype as tf.int32.
   the_plan.losses['foo'] = tf.constant(1)
   message = 'invalid loss dtype tf.int32, must be a floating point type'
   self.assertRaisesWithLiteralMatch(
       TypeError, message, the_plan.finalize_stats)
Ejemplo n.º 3
0
 def test_assert_runnable(self):
   """assert_runnable() requires a compiler first, then a logdir.

   A fresh plan fails on the missing compiler; once a compiler is set it
   fails on the missing logdir; with both set it succeeds.
   """
   p = plan.Plan(None)
   self.assertRaisesWithLiteralMatch(
       ValueError, 'compiler is required', p.assert_runnable)
   p.compiler = block_compiler.Compiler.create(blocks.Scalar())
   self.assertRaisesWithLiteralMatch(
       ValueError, 'logdir is required', p.assert_runnable)
   # Use the per-test temporary directory instead of a hard-coded POSIX
   # '/tmp/' so the test is portable. get_temp_dir() is provided by
   # tf.test.TestCase, assumed to be the base class (test_session() usage
   # elsewhere implies it) -- NOTE(review): confirm.
   p.logdir = self.get_temp_dir()
   p.assert_runnable()
Ejemplo n.º 4
0
 def test_finalize_stats_no_summaries(self):
   """finalize_stats() on a plan that never sets save_summaries_secs.

   Checks that `summaries` starts out as None, that `loss_total` evaluates
   without feeding the batch-size placeholder, and that a second call to
   finalize_stats() raises RuntimeError.
   """
   p = plan.Plan(None)
   p.losses['foo'] = tf.constant([1.0])
   p.metrics['bar'] = tf.constant(2)
   self.assertIsNone(p.summaries)
   p.finalize_stats()
   with self.test_session():
     self.assertEqual(1, p.loss_total.eval())
   # finalize_stats() is single-use.
   self.assertRaisesWithLiteralMatch(
       RuntimeError, 'finalize_stats() has already been called',
       p.finalize_stats)
Ejemplo n.º 5
0
 def test_finalize_stats_summaries_empty(self):
   """With summaries enabled but no losses or metrics, output is empty.

   The `summaries` tensor still exists and evaluates to an empty
   serialized string.
   """
   the_plan = plan.Plan(None)
   the_plan.save_summaries_secs = 42
   the_plan.finalize_stats()
   with self.test_session():
     serialized = the_plan.summaries.eval()
     self.assertEqual(b'', serialized)