def testComputeSplitsLessThanNumSplits(self):
    """ComputeSplits with a batch smaller than the number of splits.

    A batch of 2 spread over 4 splits is expected to give one example to
    each of the first two splits and zero to the remaining two.
    """
    want = [1, 1, 0, 0]
    with self.session(use_gpu=False) as sess:
      total = tf.constant(2, dtype=tf.int32)
      split_op = input_generator_helper.ComputeSplits(total, 4)
      got = sess.run(split_op)
      self.assertAllEqual(got, want)

  def testComputeSplitsUnevenThree(self):
    """ComputeSplits with a batch that does not divide evenly.

    29 examples over 4 splits: each split gets 7, and the single leftover
    example is expected to land in the first split, giving [8, 7, 7, 7].
    """
    want = [8, 7, 7, 7]
    with self.session(use_gpu=False) as sess:
      total = tf.constant(29, dtype=tf.int32)
      split_op = input_generator_helper.ComputeSplits(total, 4)
      got = sess.run(split_op)
      self.assertAllEqual(got, want)
# Example No. 3 (score: 0)
    def testComputeSplitsEven(self):
        """ComputeSplits with a batch that divides evenly.

        A batch of 32 over 4 splits is expected to yield four splits of 8.
        """
        want = [8] * 4
        with self.session(use_gpu=False):
            split_op = input_generator_helper.ComputeSplits(
                tf.constant(32, dtype=tf.int32), 4)
            got = self.evaluate(split_op)
            self.assertAllEqual(got, want)