Exemplo n.º 1
0
 def InputBatch(self):
     """Builds a batch of consecutive counter values shaped like `self.shape`.

     Each call passes the batch's element count to `IncBy` on the
     'CountingInputGenerator' stats counter. It then fills the batch with the
     run of consecutive integers that ends at the returned counter value.

     Returns:
       A `py_utils.NestedMap` with:
         src_ids: float32 tensor reshaped to `self.shape` holding the values.
         tgt_ids: `src_ids` summed over axis 0.
     """
     num_elements = tf.reduce_prod(self.shape)
     stats_counter = base_model.StatsCounter('CountingInputGenerator')
     # NOTE(review): assumes IncBy returns the counter value *after* the
     # increment, so subtracting num_elements recovers this batch's start.
     counter_after = tf.cast(
         stats_counter.IncBy(None, num_elements), dtype=tf.int32)
     first_value = tf.stop_gradient(counter_after - num_elements)
     flat_values = tf.cast(
         first_value + tf.range(num_elements), dtype=tf.float32)
     src_ids = tf.reshape(flat_values, self.shape)
     return py_utils.NestedMap(
         src_ids=src_ids, tgt_ids=tf.reduce_sum(src_ids, axis=0))
Exemplo n.º 2
0
    def testStatsCounter(self):
        """Checks that a StatsCounter starts at 0 and IncBy accumulates."""
        with self.session() as sess:
            counter = base_model.StatsCounter('foo')
            current = counter.Value()
            layer_params = base_layer.BaseLayer.Params()
            increment = counter.IncBy(layer_params, 100)

            sess.run(tf.global_variables_initializer())
            # A freshly initialized counter reads as zero.
            self.assertAllEqual(0, sess.run(current))
            # Running the increment yields the new total, which Value reflects.
            self.assertAllEqual(100, sess.run(increment))
            self.assertAllEqual(100, sess.run(current))
            # NOTE(review): these expectations rely on `current` being read
            # before `increment` applies within the same sess.run call —
            # confirm StatsCounter.Value() guarantees that ordering.
            for expected in ([100, 200], [200, 300]):
                self.assertAllEqual(expected, sess.run([current, increment]))