Beispiel #1
0
  def testNoNestedPmap(self):
    # Wrapping an already pmap-wrapped function in pmap again and running it
    # is expected to fail with a ValueError.
    devices = self._get_two_devices(require_same_type=True)
    data = tf_np.asarray(tf.convert_to_tensor(value=[1, 3]))

    def add_one(x):
      return x + 1.0

    expected_message = r"Nested pmap is not supported"
    with self.assertRaisesWithPredicateMatch(ValueError, expected_message):
      # Both the wrapping and the call stay inside the context, so the error
      # is accepted from whichever of these steps raises it.
      pmapped_once = extensions.pmap(add_one, devices=devices)
      pmapped_twice = extensions.pmap(pmapped_once, devices=devices)
      pmapped_twice(data)
Beispiel #2
0
  def testAxisName(self):
    # Smoke test: psum names the same axis ("foo") that the enclosing pmap
    # declares. Nothing is asserted about the result; the call just has to
    # complete without raising.
    devices = self._get_two_devices(require_same_type=True)
    data = tf_np.asarray(tf.convert_to_tensor(value=[1, 3]))

    def sum_over_foo(shard):
      return extensions.psum(shard, axis_name="foo")

    run_on_devices = extensions.pmap(
        sum_over_foo, axis_name="foo", devices=devices)
    run_on_devices(data)
Beispiel #3
0
  def testWrongAxisName(self):
    # psum asks for axis "bar" while the enclosing pmap declares axis "foo";
    # the mismatch is expected to be reported as a ValueError.
    devices = self._get_two_devices(require_same_type=True)
    data = tf_np.asarray(tf.convert_to_tensor(value=[1, 3]))

    def sum_over_bar(shard):
      return extensions.psum(shard, axis_name="bar")

    pattern = r"axis_name (.*) is not equal to that of the surrounding"
    with self.assertRaisesWithPredicateMatch(ValueError, pattern):
      run_on_devices = extensions.pmap(
          sum_over_bar, axis_name="foo", devices=devices)
      run_on_devices(data)
Beispiel #4
0
  def testPsum(self):
    # Input is [1, 3]; both entries of the pmapped psum result are expected
    # to equal 4 (= 1 + 3).
    devices = self._get_two_devices(require_same_type=True)
    data = tf_np.asarray(tf.convert_to_tensor(value=[1, 3]))

    def sum_across_devices(shard):
      return extensions.psum(shard)

    summed = extensions.pmap(sum_across_devices, devices=devices)(data)

    for position in (0, 1):
      self.assertAllClose(summed[position], 4)
Beispiel #5
0
  def testPmean(self):
    # Input is [1, 3]; both entries of the pmapped pmean result are expected
    # to equal 2 (their mean). Skipped whenever TPU devices are present.
    if extensions.tpu_devices():
      self.skipTest("pmean for TPU is not supported yet")

    devices = self._get_two_devices(require_same_type=True)
    data = tf_np.asarray(tf.convert_to_tensor(value=[1, 3]))

    def mean_across_devices(shard):
      return extensions.pmean(shard)

    averaged = extensions.pmap(mean_across_devices, devices=devices)(data)

    for position in (0, 1):
      self.assertAllClose(averaged[position], 2)
Beispiel #6
0
    def testPmapSimpleModel(self):
        # Runs 50 pmapped training steps of a small model: parameters are
        # copied once per device, inputs/targets are split along their leading
        # axis, and each step combines the per-device updates with psum.
        devices = self._get_two_devices(require_same_type=True)
        n_devices = len(devices)

        params, params_true, inputs, targets = generate_params_inputs_targets()

        def _train_and_reduce(params, inputs, targets, learning_rate=0.1):
            new_w, new_b = train_step(params, inputs, targets, learning_rate)

            # psum / n_devices: every device ends the step with the same,
            # averaged, parameters.
            return (extensions.psum(new_w) / n_devices,
                    extensions.psum(new_b) / n_devices)

        train_step_pmapped = extensions.pmap(_train_and_reduce,
                                             devices=devices)

        # The helpers below default to n_devices (rather than a literal 2) so
        # the sharding always agrees with the divisor used in the reduction.
        def replicate(x, num_devices=n_devices):
            # Prepend an axis of length num_devices filled with copies of x.
            return array_manipulation.broadcast_to(x,
                                                   (num_devices, ) + x.shape)

        params = tf.nest.map_structure(replicate, params)

        def reshape(x, num_devices=n_devices):
            # Split the leading axis into [num_devices, per-device size]; the
            # remaining axes are kept as they are.
            x_shape = list(x.shape)
            batch_size = x_shape[0]
            batch_size_per_device = batch_size // num_devices

            new_shape_prefix = [num_devices, batch_size_per_device]
            return array_methods.reshape(x, new_shape_prefix + x_shape[1:])

        inputs = tf.nest.map_structure(reshape, inputs)
        targets = tf.nest.map_structure(reshape, targets)

        for _ in range(50):
            params = train_step_pmapped(params, inputs, targets)

        # pmap returns sharded tensors: params[k][d] is parameter k as seen on
        # device d.

        # The reduction averages across devices, so both devices should hold
        # identical parameters.
        self.assertAllClose(params[0][0], params[0][1])
        self.assertAllClose(params[1][0], params[1][1])

        # This is not trained super well, but it usually gets "close".
        self.assertAllClose(params[0][0], params_true[0], atol=1e-1)
        self.assertAllClose(params[1][0], params_true[1], atol=1e-1)
Beispiel #7
0
  def testPsumStruct(self):
    # psum applied to a list of two arrays ([1, 3] and [2, 4]). Every leaf,
    # both inside the pmapped function and in the final result, is expected
    # to be a tf_np.ndarray; the sums are expected to be 4 and 6.
    devices = self._get_two_devices(require_same_type=True)

    def check_is_ndarray(leaf):
      self.assertIsInstance(leaf, tf_np.ndarray)

    def sum_structure(structure):
      reduced = extensions.psum(structure)
      tf.nest.map_structure(check_is_ndarray, reduced)
      return reduced

    data = [tf_np.asarray([1, 3]), tf_np.asarray([2, 4], np.int64)]
    result = extensions.pmap(sum_structure, devices=devices)(data)

    positions = [(0, 0), (0, 1), (1, 0), (1, 1)]
    # Type checks for all four entries run before any value check.
    for outer, inner in positions:
      self.assertIsInstance(result[outer][inner], tf_np.ndarray)
    for outer, inner in positions:
      expected_sum = 4 if outer == 0 else 6
      self.assertAllClose(result[outer][inner], expected_sum)
Beispiel #8
0
def _tf_pmap(*args, **kwargs):
    """Call ``tf_np_extensions.pmap`` without the ``donate_argnums`` keyword.

    All positional arguments and every other keyword argument are forwarded
    unchanged; the result of ``tf_np_extensions.pmap`` is returned.
    """
    # donate_argnums not used in TF
    forwarded = {
        key: value for key, value in kwargs.items() if key != 'donate_argnums'
    }
    return tf_np_extensions.pmap(*args, **forwarded)