示例#1
0
 def test_reduce_batch_count_mean_and_var(self):
     """Dense input with reduce_instance_dims=True reduces over every element.

     The four values 1, 2, 3, 4 give count 4, mean 2.5 and population
     variance 1.25 (i.e. sum of squared deviations divided by the count).
     """
     x = tf.constant([[[1], [2]], [[3], [4]]], dtype=tf.float32)
     count, mean, var = tf_utils.reduce_batch_count_mean_and_var(
         x, reduce_instance_dims=True)
     # self.evaluate works in both graph and eager mode and fetches all three
     # tensors in one run; per-tensor .eval() needs a v1 default session and
     # fails on eager tensors.
     count_value, mean_value, var_value = self.evaluate((count, mean, var))
     self.assertAllEqual(count_value, 4)
     self.assertAllEqual(mean_value, 2.5)
     self.assertAllEqual(var_value, 1.25)
示例#2
0
 def test_reduce_batch_count_mean_and_var_elementwise(self):
     """Dense input with reduce_instance_dims=False reduces over the batch only.

     The two instances [[1], [2]] and [[3], [4]] are combined position by
     position, so every statistic keeps the per-instance shape [2, 1]:
     position 0 sees {1, 3} (mean 2, variance 1) and position 1 sees {2, 4}
     (mean 3, variance 1).
     """
     x = tf.constant([[[1], [2]], [[3], [4]]], dtype=tf.float32)
     count, mean, var = tf_utils.reduce_batch_count_mean_and_var(
         x, reduce_instance_dims=False)
     # Single mode-agnostic fetch instead of three .eval() calls that only
     # work inside a v1 default session.
     count_value, mean_value, var_value = self.evaluate((count, mean, var))
     self.assertAllEqual(count_value, [[2.], [2.]])
     self.assertAllEqual(mean_value, [[2.], [3.]])
     self.assertAllEqual(var_value, [[1.], [1.]])
示例#3
0
 def test_reduce_batch_count_mean_and_var_sparse(self):
     """Sparse input with reduce_instance_dims=True uses only stored values.

     The dense shape holds 8 cells but only 4 values are present; the
     expected count of 4 and mean of 2.5 show that the implicit (missing)
     entries are not treated as zeros.
     """
     x = tf.SparseTensor(indices=[[0, 0], [0, 2], [1, 1], [1, 2]],
                         values=[1., 2., 3., 4.],
                         dense_shape=[2, 4])
     count, mean, var = tf_utils.reduce_batch_count_mean_and_var(
         x, reduce_instance_dims=True)
     # Single mode-agnostic fetch instead of three .eval() calls that only
     # work inside a v1 default session.
     count_value, mean_value, var_value = self.evaluate((count, mean, var))
     self.assertAllEqual(count_value, 4)
     self.assertAllEqual(mean_value, 2.5)
     self.assertAllEqual(var_value, 1.25)
示例#4
0
 def test_reduce_batch_count_mean_and_var_sparse_elementwise(self):
     """Sparse input with reduce_instance_dims=False reduces per column.

     Column 0 holds {1}, column 1 holds {3}, column 3 holds {2, 4}, and
     columns 2 and 4 hold nothing. Columns with no stored values are
     expected to report count 0, mean 0 and variance 0.
     """
     x = tf.SparseTensor(indices=[[0, 0], [0, 3], [1, 1], [1, 3]],
                         values=[1., 2., 3., 4.],
                         dense_shape=[2, 5])
     count, mean, var = tf_utils.reduce_batch_count_mean_and_var(
         x, reduce_instance_dims=False)
     # Single mode-agnostic fetch instead of three .eval() calls that only
     # work inside a v1 default session.
     count_value, mean_value, var_value = self.evaluate((count, mean, var))
     self.assertAllEqual(count_value, [1.0, 1.0, 0.0, 2.0, 0.0])
     self.assertAllEqual(mean_value, [1.0, 3.0, 0.0, 3.0, 0.0])
     self.assertAllEqual(var_value, [0.0, 0.0, 0.0, 1.0, 0.0])
示例#5
0
 def _reduce_batch_count_mean_and_var(x):
     """Delegate to tf_utils.reduce_batch_count_mean_and_var for `x`.

     Returns whatever that helper returns (a (count, mean, var) triple, going
     by how its result is unpacked in the tests above).
     NOTE(review): `reduce_instance_dims` is not a parameter here; it is read
     from an enclosing scope that is not visible in this excerpt — confirm.
     """
     batch_stats = tf_utils.reduce_batch_count_mean_and_var(
         x, reduce_instance_dims=reduce_instance_dims)
     return batch_stats