def InputBatch(self):
  """Builds a batch of consecutive values drawn from a shared counter.

  The batch advances a 'CountingInputGenerator' StatsCounter by the number of
  elements in `self.shape`. The batch holds the run of integers that starts
  at the counter's value before that increment, so each evaluation of the
  returned tensors sees the next non-overlapping run.

  `self.shape` is presumably a static shape (list or 1-D tensor) whose
  product is the element count; TODO confirm against the enclosing class.

  Returns:
    A py_utils.NestedMap with:
      src_ids: float32 tensor reshaped to `self.shape`, containing
        start, start + 1, ..., start + num_elements - 1.
      tgt_ids: `src_ids` summed over axis 0.
  """
  num_elements = tf.reduce_prod(self.shape)
  counter = base_model.StatsCounter('CountingInputGenerator')
  # IncBy yields the post-increment total; subtracting the increment recovers
  # the value the counter held before this batch.
  total_after = tf.cast(counter.IncBy(None, num_elements), dtype=tf.int32)
  start = tf.stop_gradient(total_after - num_elements)
  flat = start + tf.range(num_elements)
  src = tf.reshape(tf.cast(flat, dtype=tf.float32), self.shape)
  return py_utils.NestedMap(src_ids=src, tgt_ids=tf.reduce_sum(src, axis=0))
def testStatsCounter(self):
  """Checks StatsCounter.Value() and IncBy() across successive session runs."""
  with self.session() as sess:
    counter = base_model.StatsCounter('foo')
    current = counter.Value()
    increment = counter.IncBy(base_layer.BaseLayer.Params(), 100)
    tf.global_variables_initializer().run()

    # A freshly initialized counter reads zero.
    self.assertAllEqual(0, current.eval())
    # Running IncBy returns the new total, and Value() then reflects it.
    self.assertAllEqual(100, sess.run(increment))
    self.assertAllEqual(100, current.eval())

    # NOTE(review): each run below fetches Value() and IncBy() together and
    # expects the read to observe the pre-increment total. No explicit
    # control dependency between the two is visible here; confirm that
    # StatsCounter guarantees this ordering.
    for expected in ([100, 200], [200, 300]):
      self.assertAllEqual(expected, sess.run([current, increment]))