def SplitInputBatch(self, num_splits):
    """Partitions the preprocessed input batch into `num_splits` pieces.

    Args:
      num_splits: How many pieces to produce. Must be at least 1.

    Returns:
      A list of `.NestedMap` with one entry per split. Each entry holds the
      input tensors belonging to that split, packed back into the structure
      of the original batch.
    """
    assert num_splits >= 1
    whole_batch = self.GetPreprocessedInputBatch()
    if num_splits == 1:
        # Nothing to divide: hand the batch back untouched.
        return [whole_batch]

    assert not py_utils.use_tpu()
    # per_field[i][j] is the j-th piece of the i-th flattened field.
    per_field = ig_helper.SplitTensors(whole_batch.Flatten(), num_splits)
    # Regroup by split: take piece j of every field and re-pack it.
    return [
        whole_batch.Pack([pieces[split_idx] for pieces in per_field])
        for split_idx in range(num_splits)
    ]
def testSplitTensorsAssert(self):
    """Checks that SplitTensors rejects tensors with mismatched first dims.

    t1 and t2 have three rows while t3 has only two. The ValueError is
    expected from the SplitTensors call itself; no session is run.
    """
    t1 = tf.constant([[1], [7], [8]])
    t2 = tf.constant([[5], [9], [10]])
    t3 = tf.constant([[13], [14]])  # Two rows only, unlike t1 and t2.
    tensor_tuple = (t1, t2, t3)
    num_splits = 2

    # Raw string: the pattern contains regex escapes (\[ and \]) that are
    # invalid escape sequences in a regular string literal.
    with self.assertRaisesRegexp(
        ValueError,
        r"can't split axis of size 2 into pieces of size \[2,1\]"
    ):
        input_generator_helper.SplitTensors(tensor_tuple, num_splits)
def testSplitTensorsLessThanNumSplits(self):
    """Splitting single-row tensors two ways must fail when the op is run.

    The split op is built without error; the InvalidArgumentError is only
    expected once the op is evaluated in a session.
    """
    row_a = tf.constant([[1, 2, 3, 4]])
    row_b = tf.constant([[5, 6, 7, 8]])
    row_c = tf.constant([[13, 14, 15, 16]])
    split_op = input_generator_helper.SplitTensors((row_a, row_b, row_c), 2)

    with self.session(use_gpu=False) as sess, self.assertRaisesRegexp(
        tf.errors.InvalidArgumentError,
        'first dim of tensors in xs must be greater '
        'than num_splits'
    ):
        sess.run(split_op)
def testSplitTensorsUneven(self):
    """Three-row tensors split two ways yield pieces of two rows and one row."""
    # One entry per input tensor; each entry lists that tensor's rows.
    rows_per_tensor = [
        [[1], [4], [8]],
        [[5], [9], [10]],
        [[13], [14], [11]],
    ]
    with self.session(use_gpu=False) as sess:
        inputs = tuple(tf.constant(rows) for rows in rows_per_tensor)
        split_op = input_generator_helper.SplitTensors(inputs, 2)
        # The first split is expected to hold the first two rows, the second
        # split the remaining row.
        want = tuple(
            [np.array(rows[:2]), np.array(rows[2:])]
            for rows in rows_per_tensor
        )
        got = sess.run(split_op)
        self._assertTupleOfListsEqual(got, want)
def testSplitTensorsEven(self):
    """Two-row tensors split two ways yield one row per split."""
    # One entry per input tensor: (first row, second row).
    row_pairs = [
        ([1, 2, 3, 4], [4, 5, 6, 7]),
        ([5, 6, 7, 8], [9, 10, 11, 12]),
        ([13, 14, 15, 16], [14, 15, 16, 17]),
    ]
    with self.session(use_gpu=False) as sess:
        inputs = tuple(
            tf.constant([first, second]) for first, second in row_pairs
        )
        split_op = input_generator_helper.SplitTensors(inputs, 2)
        # Each split is expected to keep a leading dimension of size one.
        want = tuple(
            [np.array([first]), np.array([second])]
            for first, second in row_pairs
        )
        got = sess.run(split_op)
        self._assertTupleOfListsEqual(got, want)
def SplitInputBatch(self, num_splits):
    """Splits the current InputBatch into num_splits ways.

    Args:
      num_splits: The number of splits. Must be >= 1.

    Returns:
      A list of `.NestedMap`. Each `.NestedMap` represents the input tensors in
      one split.
    """
    assert num_splits >= 1
    batch = self.GetPreprocessedInputBatch()
    if num_splits == 1:
        # Special case. No split is needed: the batch is returned as-is and
        # never goes through SplitTensors/Pack, so a single-split run takes a
        # different code path than a multi-split run.
        return [batch]

    assert not py_utils.use_tpu()
    # NOTE(review): an earlier debugging comment here asked whether this step
    # is where '?' starts to appear in the printed tensors; not verified.
    field_split = ig_helper.SplitTensors(batch.Flatten(), num_splits)
    # field_split[i][j] is the j-th piece of the i-th flattened field; gather
    # piece j of every field and pack it back into the batch structure.
    return [
        batch.Pack([field_pieces[j] for field_pieces in field_split])
        for j in range(num_splits)
    ]