示例#1
0
 def testAddCounter(self):
     """Check that ``hooks.add_counter`` registers an accumulating counter.

     A single counter fed from ``a`` is expected in
     ``hooks._DEFAULT_COUNTERS_COLLECTION``. Running it with 6 and then 4
     should yield 6 and then 10, i.e. each fed value is added to a total.
     """
     a = tf.placeholder(tf.int64, shape=[])
     hooks.add_counter("sum_a", a)
     sum_a = tf.get_collection(hooks._DEFAULT_COUNTERS_COLLECTION)
     # tf.get_collection returns a list (empty when nothing was added), never
     # None, so assertIsNotNone could not fail. Check that exactly one counter
     # was registered; the single-element results below rely on this too.
     self.assertEqual(1, len(sum_a))
     with self.test_session() as sess:
         sess.run(tf.global_variables_initializer())
         self.assertEqual([6], sess.run(sum_a, feed_dict={a: 6}))
         self.assertEqual([10], sess.run(sum_a, feed_dict={a: 4}))
示例#2
0
    def _register_word_counters(self, features, labels):
        """Register word counters for the sequences in features and labels.

        For each of :obj:`features` and :obj:`labels` whose sequence length
        is available (not ``None``), a counter named ``"features"`` or
        ``"labels"`` is added under the ``words_per_sec`` variable scope,
        fed with the total of that length tensor.
        """
        # Both lengths are computed up front, features first, before the
        # variable scope is entered.
        candidates = (
            ("features", self._get_features_length(features)),
            ("labels", self._get_labels_length(labels)),
        )

        with tf.variable_scope("words_per_sec"):
            for counter_name, length in candidates:
                if length is None:
                    # No sequence length for this side: nothing to count.
                    continue
                add_counter(counter_name, tf.reduce_sum(length))