def SplitInputBatch(self, num_splits):
        """Divides the preprocessed input batch into `num_splits` pieces.

        Args:
          num_splits: How many pieces to produce; must be at least 1.

        Returns:
          A list with `num_splits` `.NestedMap` entries, each holding the input
          tensors of one piece. When `num_splits` is 1 the list contains the
          unsplit batch object itself.
        """
        assert num_splits >= 1

        batch = self.GetPreprocessedInputBatch()
        if num_splits == 1:
            # Nothing to divide: hand back the whole batch unchanged.
            return [batch]

        # Multi-way splitting is only permitted when not running on TPU.
        assert not py_utils.use_tpu()
        # per_field[f][s] is the s-th piece of the f-th flattened field.
        per_field = ig_helper.SplitTensors(batch.Flatten(), num_splits)
        pieces = []
        for s in range(num_splits):
            # Collect the s-th piece of every field and repack it into the
            # same nested structure as the original batch.
            pieces.append(batch.Pack([fld[s] for fld in per_field]))
        return pieces
  def testSplitTensorsAssert(self):
    """SplitTensors raises ValueError when leading dimensions disagree.

    The error is raised by the SplitTensors call itself; no session run is
    involved.
    """
    t1 = tf.constant([[1], [7], [8]])
    t2 = tf.constant([[5], [9], [10]])
    t3 = tf.constant([[13], [14]])

    tensor_tuple = (t1, t2, t3)
    num_splits = 2

    # t1/t2 have 3 rows but t3 has only 2. The expected error reports pieces
    # of size [2,1] (summing to 3) being applied to an axis of size 2.
    # Raw string so `\[` reaches the regex engine without an
    # invalid-escape warning.
    with self.assertRaisesRegex(
        ValueError, r"can't split axis of size 2 into pieces of size \[2,1\]"):
      input_generator_helper.SplitTensors(tensor_tuple, num_splits)
  def testSplitTensorsLessThanNumSplits(self):
    """Splitting fails at run time when the first dim is below num_splits.

    Each input has a single row but two splits are requested. Graph
    construction succeeds, and the error surfaces only on `sess.run`.
    """
    t1 = tf.constant([[1, 2, 3, 4]])
    t2 = tf.constant([[5, 6, 7, 8]])
    t3 = tf.constant([[13, 14, 15, 16]])

    tensor_tuple = (t1, t2, t3)
    num_splits = 2
    splits = input_generator_helper.SplitTensors(tensor_tuple, num_splits)

    with self.session(use_gpu=False) as sess:
      # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias,
      # which was removed in Python 3.12.
      with self.assertRaisesRegex(tf.errors.InvalidArgumentError,
                                  'first dim of tensors in xs must be greater '
                                  'than num_splits'):
        sess.run(splits)
  def testSplitTensorsUneven(self):
    """Three rows split two ways yield pieces of two rows and one row."""
    with self.session(use_gpu=False) as sess:
      inputs = (
          tf.constant([[1], [4], [8]]),
          tf.constant([[5], [9], [10]]),
          tf.constant([[13], [14], [11]]),
      )
      split_ops = input_generator_helper.SplitTensors(inputs, 2)
      # Per input tensor: first piece holds the first two rows, second piece
      # holds the remaining row.
      want = (
          [np.array([[1], [4]]), np.array([[8]])],
          [np.array([[5], [9]]), np.array([[10]])],
          [np.array([[13], [14]]), np.array([[11]])],
      )
      self._assertTupleOfListsEqual(sess.run(split_ops), want)
  def testSplitTensorsEven(self):
    """Two rows split two ways yield one row per piece."""
    with self.session(use_gpu=False) as sess:
      inputs = (
          tf.constant([[1, 2, 3, 4], [4, 5, 6, 7]]),
          tf.constant([[5, 6, 7, 8], [9, 10, 11, 12]]),
          tf.constant([[13, 14, 15, 16], [14, 15, 16, 17]]),
      )
      split_ops = input_generator_helper.SplitTensors(inputs, 2)
      # Per input tensor: a list of two single-row pieces, in row order.
      want = (
          [np.array([[1, 2, 3, 4]]), np.array([[4, 5, 6, 7]])],
          [np.array([[5, 6, 7, 8]]), np.array([[9, 10, 11, 12]])],
          [np.array([[13, 14, 15, 16]]), np.array([[14, 15, 16, 17]])],
      )
      self._assertTupleOfListsEqual(sess.run(split_ops), want)
    def SplitInputBatch(self, num_splits):
        """Splits the current InputBatch into num_splits ways.

        Args:
          num_splits: The number of splits; must be at least 1.

        Returns:
          A list of `num_splits` `.NestedMap`s. Each `.NestedMap` represents
          the input tensors in one split. When `num_splits` is 1 the list
          holds the unsplit batch object itself.
        """
        assert num_splits >= 1

        batch = self.GetPreprocessedInputBatch()
        if num_splits == 1:
            # Special case. No split is needed.
            return [batch]

        # Multi-way splitting is only permitted when not running on TPU.
        assert not py_utils.use_tpu()
        # field_split[i][j] is the j-th piece of the i-th flattened field.
        field_split = ig_helper.SplitTensors(batch.Flatten(), num_splits)
        # For each split j, gather the j-th piece of every field and repack it
        # into the batch's nested structure.
        return [
            batch.Pack([field[j] for field in field_split])
            for j in range(num_splits)
        ]