def testComputeSplitsLessThanNumSplits(self):
  """ComputeSplits with fewer examples than splits.

  A batch of 2 spread over 4 splits is expected to yield [1, 1, 0, 0]:
  the leading splits get one example each and the remaining splits are
  empty.
  """
  with self.session(use_gpu=False):
    batch_size = tf.constant(2, dtype=tf.int32)
    num_splits = 4
    splits = input_generator_helper.ComputeSplits(batch_size, num_splits)
    expected = [1, 1, 0, 0]
    # Use self.evaluate (as the other ComputeSplits tests do) rather than an
    # explicit sess.run; the session opened above is the default session.
    actual = self.evaluate(splits)
    self.assertAllEqual(actual, expected)
def testComputeSplitsUnevenThree(self):
  """ComputeSplits with a batch size that does not divide evenly.

  A batch of 29 over 4 splits is expected to yield [8, 7, 7, 7]: the
  single leftover example goes to the first split.
  """
  with self.session(use_gpu=False):
    batch_size = tf.constant(29, dtype=tf.int32)
    num_splits = 4
    splits = input_generator_helper.ComputeSplits(batch_size, num_splits)
    expected = [8, 7, 7, 7]
    # Use self.evaluate (as the other ComputeSplits tests do) rather than an
    # explicit sess.run; the session opened above is the default session.
    actual = self.evaluate(splits)
    self.assertAllEqual(actual, expected)
def testComputeSplitsEven(self):
  """ComputeSplits on an evenly divisible batch: 32 over 4 gives 8 apiece."""
  with self.session(use_gpu=False):
    total_examples = tf.constant(32, dtype=tf.int32)
    split_count = 4
    split_sizes = input_generator_helper.ComputeSplits(
        total_examples, split_count)
    self.assertAllEqual(self.evaluate(split_sizes), [8, 8, 8, 8])