Example #1
0
  def FProp(self, theta, *args):
    """FProp through multiple devices in the split.

    When the cluster has more than one device per split, every input is split
    along axis 0 into `cluster.num_devices_per_split` shards. Shard i is run
    through the sub layer under `tf.device(cluster.WorkerDeviceInModelSplit(i))`,
    and the per-shard outputs are concatenated back along axis 0.

    Args:
      theta: A NestedMap object containing weights' values of this layer and its
        children layers.
      *args: A tuple of Tensors (one or more). Every tensor's first dimension is
        the same (the batch dimension).

    Returns:
      The sub layer's output. With a single device per split, this is exactly
      what `self.sub.FProp` returns. Otherwise the per-shard outputs are
      concatenated; if that concatenation yields a list (e.g. the sub layer
      returned a tuple), it is returned as a tuple.

    Raises:
      TypeError: if any element of `args` is not a `tf.Tensor`.
    """
    p = self.params
    with tf.name_scope(p.name):
      # Validate explicitly instead of with `assert`, which Python strips when
      # run with -O.
      for x in args:
        if not isinstance(x, tf.Tensor):
          raise TypeError('%s: expected every input to be a tf.Tensor, got %r' %
                          (p.name, type(x)))
      cluster = self.cluster
      num = cluster.num_devices_per_split
      if num == 1:
        # Nothing to split: call the sub layer without an explicit device scope.
        return self.sub.FProp(theta.sub, *args)
      # inps[i] holds the axis-0 shard of every input destined for device i.
      inps = py_utils.SplitRecursively(list(args), num, axis=0)
      outs = []
      for i, xs in enumerate(inps):
        device = cluster.WorkerDeviceInModelSplit(i)
        tf.logging.info('%d on device %s', i, device)
        with tf.device(device):
          ys = self.sub.FProp(theta.sub, *xs)
          # Tuple outputs are converted to lists, the nested form that
          # ConcatRecursively concatenates element-wise; any other output
          # (e.g. a single tensor) is passed through unchanged.
          outs.append(list(ys) if isinstance(ys, tuple) else ys)
      ret = py_utils.ConcatRecursively(outs, axis=0)
      # Restore the tuple form for multi-output sub layers; otherwise `ret`
      # has the same (non-list) structure as each per-shard output.
      return tuple(ret) if isinstance(ret, list) else ret
Example #2
0
  def testSplitAndConcat(self):
    """Round-trips Tensors, lists and NestedMaps via Split/ConcatRecursively.

    Checks that SplitRecursively preserves the container structure, that every
    shard holds the expected slice, and that ConcatRecursively restores the
    original values.
    """
    # Expected halves of `m3x4` when split in two along the default axis
    # (the columns, per the values asserted below).
    left_half = [[0, 1], [4, 5], [8, 9]]
    right_half = [[2, 3], [6, 7], [10, 11]]
    with self.session():
      # Split a Tensor.
      m3x4 = tf.constant(np.arange(12).reshape([3, 4]))
      splits = py_utils.SplitRecursively(m3x4, 2)
      self.assertEqual(2, len(splits))
      for split in splits:
        self.assertIsInstance(split, tf.Tensor)
      self.assertAllClose(left_half, splits[0].eval())
      self.assertAllClose(right_half, splits[1].eval())
      concatenated = py_utils.ConcatRecursively(splits)
      self.assertAllClose(m3x4.eval(), concatenated.eval())

      # Split along axis 0: each row becomes its own shard.
      splits = py_utils.SplitRecursively(m3x4, 3, axis=0)
      self.assertEqual(3, len(splits))
      concatenated = py_utils.ConcatRecursively(splits, axis=0)
      self.assertAllClose(m3x4.eval(), concatenated.eval())
      self.assertAllClose([[0, 1, 2, 3]], splits[0].eval())
      self.assertAllClose([[4, 5, 6, 7]], splits[1].eval())
      self.assertAllClose([[8, 9, 10, 11]], splits[2].eval())

      # Split a list: each shard is a list with one piece per input element.
      list_3 = [m3x4] * 3
      splits = py_utils.SplitRecursively(list_3, 2)
      self.assertEqual(2, len(splits))
      for split in splits:
        self.assertIsInstance(split, list)
        self.assertEqual(len(list_3), len(split))
      for x in splits[0]:
        self.assertAllClose(left_half, x.eval())
      for x in splits[1]:
        self.assertAllClose(right_half, x.eval())
      concatenated = py_utils.ConcatRecursively(splits)
      self.assertAllClose([x.eval() for x in list_3],
                          [x.eval() for x in concatenated])

      # Split a NestedMap: every field is split recursively.
      map_ab = py_utils.NestedMap(a=m3x4, b=list_3)
      splits = py_utils.SplitRecursively(map_ab, 2)
      self.assertEqual(2, len(splits))
      for split in splits:
        self.assertIsInstance(split, py_utils.NestedMap)
        self.assertIsInstance(split.a, tf.Tensor)
        self.assertIsInstance(split.b, list)
      self.assertAllClose(left_half, splits[0].a.eval())
      self.assertAllClose(right_half, splits[1].a.eval())
      for x in splits[0].b:
        self.assertAllClose(left_half, x.eval())
      for x in splits[1].b:
        self.assertAllClose(right_half, x.eval())
      concatenated = py_utils.ConcatRecursively(splits)
      self.assertAllClose(map_ab.a.eval(), concatenated.a.eval())
      self.assertAllClose([x.eval() for x in map_ab.b],
                          [x.eval() for x in concatenated.b])