def _InputBatch(self):
  """Builds one batch of consecutive counter-derived values.

  The 'CountingInputGenerator' stats counter is advanced by the number of
  elements in `self.shape`; the batch is filled with that many consecutive
  integers beginning at the incremented counter value minus the increment.

  Returns:
    A `py_utils.NestedMap` with:
      src_ids: the consecutive values cast to float32 and reshaped to
        `self.shape`.
      tgt_ids: `src_ids` summed over axis 0.
  """
  num_elements = tf.reduce_prod(self.shape)
  stats_counter = summary_utils.StatsCounter('CountingInputGenerator')

  # NOTE(review): presumably IncBy returns the post-increment total, so
  # subtracting the increment yields the first value of this batch -- confirm
  # against summary_utils.StatsCounter.
  bumped_total = tf.cast(stats_counter.IncBy(num_elements), dtype=tf.int32)
  # The starting value is explicitly excluded from gradient computation.
  batch_start = tf.stop_gradient(bumped_total - num_elements)

  offsets = tf.range(num_elements)
  flat_values = batch_start + offsets
  source_ids = tf.reshape(
      tf.cast(flat_values, dtype=tf.float32), self.shape)
  target_ids = tf.reduce_sum(source_ids, axis=0)

  return py_utils.NestedMap(src_ids=source_ids, tgt_ids=target_ids)
def testStatsCounter(self):
  """Checks StatsCounter value/increment fetches and its summary tag.

  Creates a counter named 'foo' under a testing worker cluster with
  summaries enabled, then verifies the values fetched before, during and
  after repeated increments of 100, and finally that the merged summary
  contains a tag mentioning 'foo'.
  """
  with self.session() as sess, cluster_factory.ForTestingWorker(
      add_summary=True):
    counter = summary_utils.StatsCounter('foo')
    current = counter.Value()
    bump = counter.IncBy(100)

    tf.global_variables_initializer().run()

    # Counter reads zero right after initialization.
    self.assertAllEqual(0, current.eval())

    # Running the increment alone yields the updated total, which a later
    # read of the value also observes.
    self.assertAllEqual(100, sess.run(bump))
    self.assertAllEqual(100, current.eval())

    # Fetching value and increment in the same run is expected to give the
    # total before and after that increment, respectively.
    both = [current, bump]
    self.assertAllEqual([100, 200], sess.run(both))
    self.assertAllEqual([200, 300], sess.run(both))

    # The merged summary must carry at least one tag containing 'foo'.
    serialized = sess.run(tf.summary.merge_all())
    summary = tf.Summary.FromString(serialized)
    tags = [entry.tag for entry in summary.value]
    self.assertTrue(any('foo' in tag for tag in tags))