def FProp(self, theta, *args):
  """Runs the sub layer data-parallel over the devices of one model split.

  The positional inputs are cut into equal shards along axis 0, one shard per
  device in the split. Each shard goes through `self.sub` on its own worker
  device, and the per-device results are stitched back together along axis 0.
  When the split has a single device, the sub layer is called directly on the
  unsplit inputs.

  Args:
    theta: A NestedMap object containing weights' values of this layer and
      its children layers. Only `theta.sub` is forwarded to the sub layer.
    *args: A tuple of Tensors (one or more). Every tensor's first dimension
      is the same (the batch dimension).

  Returns:
    The sub layer's output. When the sub layer returns a tuple, a tuple of
    per-position concatenations is returned. Otherwise the concatenated
    single result is returned.
  """
  with tf.name_scope(self.params.name):
    assert all(isinstance(x, tf.Tensor) for x in args)
    cluster = self.cluster
    num_devices = cluster.num_devices_per_split

    # Single device: nothing to shard, run the sub layer in place.
    if num_devices == 1:
      return self.sub.FProp(theta.sub, *args)

    shards = py_utils.SplitRecursively(list(args), num_devices, axis=0)
    per_device_outputs = []
    for device_index, shard_inputs in enumerate(shards):
      device = cluster.WorkerDeviceInModelSplit(device_index)
      tf.logging.info('%d on device %s', device_index, device)
      with tf.device(device):
        shard_outputs = self.sub.FProp(theta.sub, *shard_inputs)
        # Tuple outputs are turned into lists before concatenation. Any other
        # output (e.g. a single tensor) is collected as-is.
        per_device_outputs.append(
            list(shard_outputs) if isinstance(shard_outputs, tuple)
            else shard_outputs)

    merged = py_utils.ConcatRecursively(per_device_outputs, axis=0)
    # Undo the tuple->list conversion done above for multi-output layers.
    return tuple(merged) if isinstance(merged, list) else merged
def testSplitAndConcat(self):
  """Round-trips a Tensor, a list and a NestedMap through Split/Concat.

  Checks the number of splits, the type and length of every split, the
  values of every piece (not only the first), and that ConcatRecursively
  restores the original values.
  """
  with self.session():
    m3x4 = tf.constant(np.arange(12).reshape([3, 4]))
    # Expected pieces when m3x4 is split in 2 along the default axis. As
    # asserted below, the default split separates the columns into halves.
    left_half = [[0, 1], [4, 5], [8, 9]]
    right_half = [[2, 3], [6, 7], [10, 11]]
    halves = (left_half, right_half)

    # Split a Tensor.
    splits = py_utils.SplitRecursively(m3x4, 2)
    self.assertEqual(2, len(splits))
    for split in splits:
      self.assertIsInstance(split, tf.Tensor)
    self.assertAllClose(left_half, splits[0].eval())
    self.assertAllClose(right_half, splits[1].eval())
    concatenated = py_utils.ConcatRecursively(splits)
    self.assertAllClose(m3x4.eval(), concatenated.eval())

    # Split along axis 0: each of the 3 pieces is exactly one row.
    splits = py_utils.SplitRecursively(m3x4, 3, axis=0)
    self.assertEqual(3, len(splits))
    expected_rows = ([[0, 1, 2, 3]], [[4, 5, 6, 7]], [[8, 9, 10, 11]])
    for expected, split in zip(expected_rows, splits):
      self.assertAllClose(expected, split.eval())
    concatenated = py_utils.ConcatRecursively(splits, axis=0)
    self.assertAllClose(m3x4.eval(), concatenated.eval())

    # Split a list: every element is split independently. The length checks
    # keep the per-element value loops from passing vacuously on an empty
    # split.
    list_3 = [m3x4] * 3
    splits = py_utils.SplitRecursively(list_3, 2)
    self.assertEqual(2, len(splits))
    for split, expected in zip(splits, halves):
      self.assertIsInstance(split, list)
      self.assertEqual(len(list_3), len(split))
      for x in split:
        self.assertAllClose(expected, x.eval())
    concatenated = py_utils.ConcatRecursively(splits)
    self.assertAllClose([x.eval() for x in list_3],
                        [x.eval() for x in concatenated])

    # Split a NestedMap: both the tensor field `a` and the list field `b`
    # are checked in every split.
    map_ab = py_utils.NestedMap(a=m3x4, b=list_3)
    splits = py_utils.SplitRecursively(map_ab, 2)
    self.assertEqual(2, len(splits))
    for split, expected in zip(splits, halves):
      self.assertIsInstance(split, py_utils.NestedMap)
      self.assertIsInstance(split.a, tf.Tensor)
      self.assertIsInstance(split.b, list)
      self.assertAllClose(expected, split.a.eval())
      self.assertEqual(len(list_3), len(split.b))
      for x in split.b:
        self.assertAllClose(expected, x.eval())
    concatenated = py_utils.ConcatRecursively(splits)
    self.assertAllClose(map_ab.a.eval(), concatenated.a.eval())
    self.assertAllClose([x.eval() for x in map_ab.b],
                        [x.eval() for x in concatenated.b])