def test_finalize_stats_summaries(self):
    """Summaries hold per-example losses, their total, and raw metrics.

    Losses 'foo' (sum 1) and 'bar' (sum 5) plus their total (6) are divided
    by the fed batch size.  Scalar metric 'baz' is reported as-is, and the
    vector metric 'qux' is reported as a histogram.
    """
    p = plan.Plan(None)
    p.save_summaries_secs = 42
    p.losses['foo'] = tf.constant([1.0])
    p.losses['bar'] = tf.constant([2.0, 3.0])
    p.metrics['baz'] = tf.constant(4)
    p.metrics['qux'] = tf.constant([5.0, 6.0])
    p.finalize_stats()
    with self.test_session():
        batch_of_one = {p.batch_size_placeholder: 1}
        self.assertEqual(6, p.loss_total.eval(batch_of_one))
        summary = tf.Summary()
        summary.ParseFromString(p.summaries.eval(batch_of_one))

        # Build a reference histogram for 'qux' from an equivalent summary op
        # so the comparison does not hard-code bucket boundaries.
        reference = tf.Summary()
        reference.ParseFromString(tf.summary.histogram('qux', [5, 6]).eval())
        qux_histogram = reference.value[0].histo

        def expected_for(foo, bar, total):
            # Only the loss-derived scalars change with the batch size.
            return [
                tf.Summary.Value(tag='foo', simple_value=foo),
                tf.Summary.Value(tag='bar', simple_value=bar),
                tf.Summary.Value(tag='loss_total', simple_value=total),
                tf.Summary.Value(tag='baz', simple_value=4),
                tf.Summary.Value(tag='qux', histo=qux_histogram),
            ]

        six.assertCountEqual(self, expected_for(1, 5, 6), summary.value)

        summary.ParseFromString(
            p.summaries.eval({p.batch_size_placeholder: 2}))
        six.assertCountEqual(self, expected_for(0.5, 2.5, 3), summary.value)
def test_finalize_stats_invalid_loss_dtype(self):
    """An integer-valued loss makes finalize_stats raise TypeError."""
    plan_under_test = plan.Plan(None)
    plan_under_test.losses['foo'] = tf.constant(1)  # int32, not float
    expected_message = (
        'invalid loss dtype tf.int32, must be a floating point type')
    self.assertRaisesWithLiteralMatch(
        TypeError, expected_message, plan_under_test.finalize_stats)
def test_assert_runnable(self):
    """assert_runnable demands a compiler first, then a logdir."""
    candidate = plan.Plan(None)
    check = candidate.assert_runnable

    # A fresh plan has no compiler.
    self.assertRaisesWithLiteralMatch(
        ValueError, 'compiler is required', check)

    # With a compiler but no logdir, the logdir is the missing piece.
    candidate.compiler = block_compiler.Compiler.create(blocks.Scalar())
    self.assertRaisesWithLiteralMatch(
        ValueError, 'logdir is required', check)

    # Both present: no exception expected.
    candidate.logdir = '/tmp/'
    check()
def test_finalize_stats_no_summaries(self):
    """Without save_summaries_secs, no summary op exists and the loss is raw.

    Also checks that finalize_stats() refuses to run a second time.
    """
    p = plan.Plan(None)
    p.losses['foo'] = tf.constant([1.0])
    p.metrics['bar'] = tf.constant(2)
    # Identity check: `==` may be overloaded on TF objects, and assertIsNone
    # gives a clearer failure message than assertEqual(x, None).
    self.assertIsNone(p.summaries)
    p.finalize_stats()
    with self.test_session():
        # No batch_size_placeholder is fed here; this test relies on the
        # no-summaries path not requiring it.
        self.assertEqual(1, p.loss_total.eval())
    self.assertRaisesWithLiteralMatch(
        RuntimeError, 'finalize_stats() has already been called',
        p.finalize_stats)
def test_finalize_stats_summaries_empty(self):
    """Summaries enabled with no losses or metrics serialize to empty bytes."""
    p = plan.Plan(None)
    p.save_summaries_secs = 42
    p.finalize_stats()
    with self.test_session():
        serialized = p.summaries.eval()
        self.assertEqual(b'', serialized)