Example #1
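  # pmap with explicit one-device lists: the same function is mapped onto d0
  # and d1 separately, so the test needs at least two devices to be present.
  # (Throughout these snippets, np is jax.numpy and onp is plain NumPy, the
  # old JAX test-suite convention.)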
  def testOneDevice(self):
    if xla_bridge.device_count() == 1:
      raise SkipTest("this test requires multiple devices")

    d0 = xla_bridge.devices()[0]
    d1 = xla_bridge.devices()[1]
    f = lambda x: np.dot(x, x.T)
    f0 = pmap(f, devices=[d0])
    f1 = pmap(f, devices=[d1])
    x = onp.random.rand(1, 1000, 1000)
    r0 = f0(x)
    r1 = f1(x)
    expected = onp.expand_dims(onp.dot(x.squeeze(), x.squeeze().T), 0)
    self.assertAllClose(r0, expected, check_dtypes=True, atol=1e-6, rtol=1e-3)
    self.assertAllClose(r1, expected, check_dtypes=True, atol=1e-6, rtol=1e-3)
Example #2
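  # Nested pmaps must not pass an explicit `devices` argument, whether it is
  # given on the outer or on the inner pmap.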
  def testNestedPmapsError(self):
    # Devices specified in outer pmap
    @partial(pmap, axis_name='i', devices=xla_bridge.devices())
    def foo(x):
      @partial(pmap, axis_name='j')
      def bar(y):
        return lax.psum(y, 'j')
      return bar(x)

    with self.assertRaisesRegex(
        ValueError,
        "Nested pmaps with explicit devices argument."):
      foo(np.ones((xla_bridge.device_count(), 1)))

    # Devices specified in inner pmap
    @partial(pmap, axis_name='i')
    def foo(x):
      @partial(pmap, axis_name='j', devices=xla_bridge.devices())
      def bar(y):
        return lax.psum(y, 'j')
      return bar(x)

    with self.assertRaisesRegex(
        ValueError,
        "Nested pmaps with explicit devices argument."):
      foo(np.ones((xla_bridge.device_count(), 1)))
Example #3
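# Dispatches to a parallel training loop only when the backend supports it
# (see can_train_parallel in Example #15); otherwise it falls back to local
# training and says why.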
def train_loop(key,
               init_params,
               loss_fn,
               parallel=True,
               summarize_fn=default_summarize,
               lr=1e-4,
               num_steps=int(1e5),
               summarize_every=100,
               checkpoint_every=5000,
               clobber_checkpoint=False,
               logdir="/tmp/lda_inference"):

    if not parallel:
        train_fn = local_train_loop
    elif can_train_parallel():
        train_fn = parallel_train_loop
    else:
        print(
            "Platform is %s and num devices is %d, defaulting to local training."
            % (xla_bridge.get_backend().platform, len(xla_bridge.devices())))
        train_fn = local_train_loop

    train_fn(key,
             init_params,
             loss_fn,
             summarize_fn=summarize_fn,
             lr=lr,
             num_steps=num_steps,
             summarize_every=summarize_every,
             checkpoint_every=checkpoint_every,
             clobber_checkpoint=clobber_checkpoint,
             logdir=logdir)
Example #4
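 # pmap over every local device by passing the full device list explicitly;
 # lax.psum then reduces across all of them.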
 def testAllDevices(self):
   f = pmap(lambda x: x - lax.psum(x, 'i'), axis_name='i',
            devices=xla_bridge.devices())
   shape = (xla_bridge.device_count(), 4)
   x = onp.arange(prod(shape), dtype=onp.float32).reshape(shape)
   expected = x - onp.sum(x, 0)
   ans = f(x)
   self.assertAllClose(ans, expected, check_dtypes=True)
Example #5
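  # Gradients taken through a pmapped function agree with gradients of the
  # equivalent un-pmapped computation.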
  def testGradBasic(self):
    @partial(pmap, axis_name='i', devices=xla_bridge.devices())
    def f(x):
      return np.sin(x)

    shape = (xla_bridge.device_count(), 4)
    x = onp.arange(prod(shape), dtype=onp.float32).reshape(shape)

    ans = grad(lambda x: np.sum(np.sin(x)))(x)
    expected = grad(lambda x: np.sum(f(x)))(x)
    self.assertAllClose(ans, expected, check_dtypes=False)
Example #6
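  # A jit-compiled function can be called inside a pmapped function; the
  # collective psum still sees the per-device results.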
  def testJitInPmap(self):
    @partial(pmap, axis_name='i', devices=xla_bridge.devices())
    def foo(x):
      @jit
      def bar(y):
        return y + 1
      return lax.psum(bar(x), 'i')

    ndevices = xla_bridge.device_count()
    ans = foo(np.ones((ndevices, 1)))
    expected = onp.ones((ndevices, 1), dtype=np.float_) * ndevices * 2
    self.assertAllClose(ans, expected, check_dtypes=True)
Example #7
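  # Replicating a constant-returning pmap fails when the requested replica
  # count exceeds the available (or explicitly supplied) local devices.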
  def testPmapConstantError(self):
    device_count = xla_bridge.device_count()
    f = pmap(lambda x: 3)
    x = np.arange(device_count + 1)
    self.assertRaisesRegex(
        ValueError, r"Cannot replicate across \d+ replicas because only \d+ "
        r"local devices are available.", lambda: f(x))

    f = pmap(lambda x: 3, devices=[xla_bridge.devices()[0]])
    x = np.arange(2)
    self.assertRaisesRegex(
        ValueError, "Cannot replicate across 2 replicas because only 1 "
        "local devices are available.", lambda: f(x))
Example #8
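    # A constant-returning pmap replicates its output onto exactly the
    # devices passed in, independent of their order.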
    def testPmapConstantDevices(self):
        if xla_bridge.device_count() == 1:
            raise SkipTest("this test requires multiple devices")

        devices = xla_bridge.devices()[:-1]
        shuffle(devices)
        f = pmap(lambda x: 3, devices=devices)
        x = np.arange(len(devices))
        ans = f(x)
        expected = onp.repeat(3, len(devices))
        self.assertAllClose(ans, expected, check_dtypes=False)

        # Test that 'ans' was properly replicated across devices.
        self.assertEqual([b.device() for b in ans.device_buffers], devices)
Example #9
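  # Nested pmaps multiply the replica count (outer axis times inner axis),
  # so an oversized leading shape exceeds the available devices.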
  def testNestedPmapConstantError(self):
    f = pmap(pmap(lambda x: 3))
    shape = (2, xla_bridge.device_count() // 2 + 1, 3)
    x = np.arange(prod(shape)).reshape(shape)
    self.assertRaisesRegex(
        ValueError, r"Cannot replicate across \d+ replicas because only \d+ "
        r"local devices are available.", lambda: f(x))

    if xla_bridge.device_count() > 1:
      f = pmap(pmap(lambda x: 3), devices=xla_bridge.devices()[:-1])
      shape = (2, xla_bridge.device_count() // 2, 3)
      x = np.arange(prod(shape)).reshape(shape)
      self.assertRaisesRegex(
          ValueError, r"Cannot replicate across \d+ replicas because only \d+ "
          r"local devices are available.", lambda: f(x))
Example #10
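  # With an explicit devices list, the input's leading axis must match the
  # device count exactly; both too-small and too-large inputs should fail.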
  def testBadAxisSizeError(self):
    if xla_bridge.device_count() == 1:
      raise SkipTest("this test requires multiple devices")

    f = pmap(lambda x: lax.psum(x, 'i'), axis_name='i',
             devices=xla_bridge.devices())
    with self.assertRaisesRegex(
        ValueError, r"compiling computation that requires 1 replicas, "
        r"but \d+ devices were specified"):
      f(np.ones(1))

    with self.assertRaisesRegex(
        ValueError, r"compiling computation that requires \d+ replicas, "
        r"but \d+ devices were specified"):
      f(np.ones(xla_bridge.device_count() + 1))
Example #11
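  # Same test as Example #10, but matching the reworded error message of a
  # different (presumably later) JAX version.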
  def testBadAxisSizeError(self):
    if xla_bridge.device_count() == 1:
      raise SkipTest("this test requires multiple devices")

    f = pmap(lambda x: lax.psum(x, 'i'), axis_name='i',
             devices=xla_bridge.devices())
    with self.assertRaisesRegex(
        ValueError, r"Leading axis size of input to pmapped function must "
        r"equal the number of local devices passed to pmap. Got axis_size=1, "
        r"num_local_devices=\d."):
      f(np.ones(1))

    with self.assertRaisesRegex(
        ValueError, r"Leading axis size of input to pmapped function must "
        r"equal the number of local devices passed to pmap. Got axis_size=\d, "
        r"num_local_devices=\d."):
      f(np.ones(xla_bridge.device_count() + 1))
Example #12
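    # pmap accepts a ShardedDeviceArray whose existing sharding does not
    # match, and reshards the input across the devices it actually needs.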
    def testReshardInput(self):
        if xla_bridge.device_count() < 6:
            raise SkipTest("testReshardInput requires 6 devices")
        # Manually construct a ShardedDeviceArray with the wrong sharding for the
        # subsequent pmap
        shard_shape = (3, 2)
        shard = np.arange(np.prod(shard_shape)).reshape(shard_shape)
        bufs = [xla.device_put(shard, d) for d in xla_bridge.devices()[:4]]
        aval = ShapedArray((6, 4), shard.dtype)
        sharding_spec = pxla.ShardingSpec(shards_per_axis=(2, 2),
                                          is_axis_materialized=(True, True),
                                          replication_factor=2)
        arr = pxla.ShardedDeviceArray(aval, sharding_spec, bufs)

        r = pmap(lambda x: x + 1)(arr)
        self.assertAllClose(r, arr + 1, check_dtypes=True)
        self.assertEqual(len(r.device_buffers), 6)
Example #13
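    # Intentionally disabled: the unconditional SkipTest below makes the rest
    # of the body unreachable until nested pmap with devices is implemented.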
    def testNestedPmapConstantDevices(self):
        raise SkipTest("Nested pmaps with devices not yet implemented")

        if xla_bridge.device_count() < 6:
            raise SkipTest("this test requires >= 6 devices")

        devices = xla_bridge.devices()[:-2]
        shuffle(devices)
        f = pmap(pmap(lambda x: 3), devices=devices)
        shape = (2, len(devices) // 2, 3)
        x = np.arange(prod(shape)).reshape(shape)
        ans = f(x)
        expected = 3 * onp.ones(shape[:2])
        self.assertAllClose(ans, expected, check_dtypes=False)

        # Test that 'ans' was properly replicated across devices.
        expected_sharded = pmap(pmap(lambda x: x), devices=devices)(expected)
        self.assertEqual([b.device() for b in ans.device_buffers],
                         [b.device() for b in expected_sharded.device_buffers])
Example #14
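 # The inner function is pmapped with an explicit devices list; calling foo
 # under an outer pmap triggers the nested-devices error from Example #2.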
 def foo(x):
   @partial(pmap, axis_name='i', devices=xla_bridge.devices())
   def bar(y):
     return lax.psum(y, 'i')
   return bar(x)
Example #15
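# Parallel training is only attempted on a TPU backend with more than one
# device attached.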
def can_train_parallel():
    return (xla_bridge.get_backend().platform == "tpu"
            and len(xla_bridge.devices()) > 1)
Example #16
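 # jit's `device` argument pins the computation to the chosen device; here
 # xb aliases xla_bridge, and the output buffer lands on that device.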
 def test_jit_device(self):
     device = xb.devices()[-1]
     x = api.jit(lambda x: x, device=device)(3.)
     self.assertIsInstance(x, DeviceArray)
     self.assertEqual(x.device_buffer.device(), device)