def testOneDevice(self):
  if xla_bridge.device_count() == 1:
    raise SkipTest("this test requires multiple devices")

  # Run the same single-replica pmap on two different devices and check that
  # both agree with a plain numpy reference.
  d0 = xla_bridge.devices()[0]
  d1 = xla_bridge.devices()[1]
  f = lambda x: np.dot(x, x.T)
  f0 = pmap(f, devices=[d0])
  f1 = pmap(f, devices=[d1])
  x = onp.random.rand(1, 1000, 1000)
  r0 = f0(x)
  r1 = f1(x)
  expected = onp.expand_dims(onp.dot(x.squeeze(), x.squeeze().T), 0)
  self.assertAllClose(r0, expected, check_dtypes=True, atol=1e-6, rtol=1e-3)
  self.assertAllClose(r1, expected, check_dtypes=True, atol=1e-6, rtol=1e-3)
def testNestedPmapsError(self):
  # Devices specified in outer pmap
  @partial(pmap, axis_name='i', devices=xla_bridge.devices())
  def foo(x):
    @partial(pmap, axis_name='j')
    def bar(y):
      return lax.psum(y, 'j')
    return bar(x)

  with self.assertRaisesRegex(
      ValueError,
      "Nested pmaps with explicit devices argument."):
    foo(np.ones((xla_bridge.device_count(), 1)))

  # Devices specified in inner pmap
  @partial(pmap, axis_name='i')
  def foo(x):
    @partial(pmap, axis_name='j', devices=xla_bridge.devices())
    def bar(y):
      return lax.psum(y, 'j')
    return bar(x)

  with self.assertRaisesRegex(
      ValueError,
      "Nested pmaps with explicit devices argument."):
    foo(np.ones((xla_bridge.device_count(), 1)))
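# For contrast, a hedged sketch of the pattern the error message steers you
# toward: nesting pmaps is supported as long as neither level passes an
# explicit `devices` argument. This is an illustration, not one of the tests
# above; it assumes the total device count is divisible by the outer axis size.
#
#   @partial(pmap, axis_name='i')
#   def outer(x):
#     @partial(pmap, axis_name='j')
#     def inner(y):
#       return lax.psum(y, 'j')
#     return inner(x)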
def train_loop(key,
               init_params,
               loss_fn,
               parallel=True,
               summarize_fn=default_summarize,
               lr=1e-4,
               num_steps=int(1e5),
               summarize_every=100,
               checkpoint_every=5000,
               clobber_checkpoint=False,
               logdir="/tmp/lda_inference"):
  # Pick the training implementation: use local training when parallelism
  # was not requested, or when the platform cannot support it.
  if not parallel:
    train_fn = local_train_loop
  elif can_train_parallel():
    train_fn = parallel_train_loop
  else:
    print("Platform is %s and num devices is %d, defaulting to local training."
          % (xla_bridge.get_backend().platform, len(xla_bridge.devices())))
    train_fn = local_train_loop

  train_fn(key, init_params, loss_fn,
           summarize_fn=summarize_fn,
           lr=lr,
           num_steps=num_steps,
           summarize_every=summarize_every,
           checkpoint_every=checkpoint_every,
           clobber_checkpoint=clobber_checkpoint,
           logdir=logdir)
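# A minimal usage sketch. `toy_loss` and `toy_params` are hypothetical
# stand-ins for a real model, the (params, key) signature of loss_fn is an
# assumption, and `random` is jax.random:
#
#   def toy_loss(params, key):
#     return np.sum(params ** 2)
#
#   toy_params = np.zeros(10)
#   train_loop(random.PRNGKey(0), toy_params, toy_loss,
#              parallel=True, num_steps=1000, logdir="/tmp/toy_run")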
def testAllDevices(self):
  f = pmap(lambda x: x - lax.psum(x, 'i'), axis_name='i',
           devices=xla_bridge.devices())
  shape = (xla_bridge.device_count(), 4)
  x = onp.arange(prod(shape), dtype=onp.float32).reshape(shape)
  expected = x - onp.sum(x, 0)
  ans = f(x)
  self.assertAllClose(ans, expected, check_dtypes=True)
def testGradBasic(self):
  @partial(pmap, axis_name='i', devices=xla_bridge.devices())
  def f(x):
    return np.sin(x)

  shape = (xla_bridge.device_count(), 4)
  x = onp.arange(prod(shape), dtype=onp.float32).reshape(shape)

  ans = grad(lambda x: np.sum(np.sin(x)))(x)
  expected = grad(lambda x: np.sum(f(x)))(x)
  self.assertAllClose(ans, expected, check_dtypes=False)
def testJitInPmap(self):
  @partial(pmap, axis_name='i', devices=xla_bridge.devices())
  def foo(x):
    @jit
    def bar(y):
      return y + 1
    return lax.psum(bar(x), 'i')

  # bar maps each 1 to 2, then psum adds the 2s across all devices.
  ndevices = xla_bridge.device_count()
  ans = foo(np.ones((ndevices, 1)))
  expected = onp.ones((ndevices, 1), dtype=np.float_) * ndevices * 2
  self.assertAllClose(ans, expected, check_dtypes=True)
def testPmapConstantError(self):
  device_count = xla_bridge.device_count()
  f = pmap(lambda x: 3)
  x = np.arange(device_count + 1)
  self.assertRaisesRegex(
      ValueError,
      r"Cannot replicate across \d+ replicas because only \d+ "
      r"local devices are available.",
      lambda: f(x))

  f = pmap(lambda x: 3, devices=[xla_bridge.devices()[0]])
  x = np.arange(2)
  self.assertRaisesRegex(
      ValueError,
      "Cannot replicate across 2 replicas because only 1 "
      "local devices are available.",
      lambda: f(x))
def testPmapConstantDevices(self):
  if xla_bridge.device_count() == 1:
    raise SkipTest("this test requires multiple devices")

  devices = xla_bridge.devices()[:-1]
  shuffle(devices)
  f = pmap(lambda x: 3, devices=devices)
  x = np.arange(len(devices))
  ans = f(x)
  expected = onp.repeat(3, len(devices))
  self.assertAllClose(ans, expected, check_dtypes=False)

  # Test that 'ans' was properly replicated across devices.
  self.assertEqual([b.device() for b in ans.device_buffers], devices)
def testNestedPmapConstantError(self):
  f = pmap(pmap(lambda x: 3))
  shape = (2, xla_bridge.device_count() // 2 + 1, 3)
  x = np.arange(prod(shape)).reshape(shape)
  self.assertRaisesRegex(
      ValueError,
      r"Cannot replicate across \d+ replicas because only \d+ "
      r"local devices are available.",
      lambda: f(x))

  if xla_bridge.device_count() > 1:
    f = pmap(pmap(lambda x: 3), devices=xla_bridge.devices()[:-1])
    shape = (2, xla_bridge.device_count() // 2, 3)
    x = np.arange(prod(shape)).reshape(shape)
    self.assertRaisesRegex(
        ValueError,
        r"Cannot replicate across \d+ replicas because only \d+ "
        r"local devices are available.",
        lambda: f(x))
def testBadAxisSizeError(self):
  if xla_bridge.device_count() == 1:
    raise SkipTest("this test requires multiple devices")

  f = pmap(lambda x: lax.psum(x, 'i'), axis_name='i',
           devices=xla_bridge.devices())
  with self.assertRaisesRegex(
      ValueError,
      r"Leading axis size of input to pmapped function must "
      r"equal the number of local devices passed to pmap. Got axis_size=1, "
      r"num_local_devices=\d."):
    f(np.ones(1))

  with self.assertRaisesRegex(
      ValueError,
      r"Leading axis size of input to pmapped function must "
      r"equal the number of local devices passed to pmap. Got axis_size=\d, "
      r"num_local_devices=\d."):
    f(np.ones(xla_bridge.device_count() + 1))
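# For contrast, a hedged sketch of the call that succeeds: when the leading
# axis size matches the number of devices passed to pmap, psum reduces the
# ones across all of them, so every entry of the result equals device_count:
#
#   n = xla_bridge.device_count()
#   f(np.ones(n))  # == onp.full(n, n)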
def testReshardInput(self):
  if xla_bridge.device_count() < 6:
    raise SkipTest("testReshardInput requires 6 devices")

  # Manually construct a ShardedDeviceArray with the wrong sharding for the
  # subsequent pmap.
  shard_shape = (3, 2)
  shard = np.arange(np.prod(shard_shape)).reshape(shard_shape)
  bufs = [xla.device_put(shard, d) for d in xla_bridge.devices()[:4]]
  aval = ShapedArray((6, 4), shard.dtype)
  sharding_spec = pxla.ShardingSpec(shards_per_axis=(2, 2),
                                    is_axis_materialized=(True, True),
                                    replication_factor=2)
  arr = pxla.ShardedDeviceArray(aval, sharding_spec, bufs)

  r = pmap(lambda x: x + 1)(arr)
  self.assertAllClose(r, arr + 1, check_dtypes=True)
  # The pmap should have resharded the input across the 6 devices it ran on.
  self.assertEqual(len(r.device_buffers), 6)
def testNestedPmapConstantDevices(self):
  raise SkipTest("Nested pmaps with devices not yet implemented")

  if xla_bridge.device_count() < 6:
    raise SkipTest("this test requires >= 6 devices")

  devices = xla_bridge.devices()[:-2]
  shuffle(devices)
  f = pmap(pmap(lambda x: 3), devices=devices)
  shape = (2, len(devices) // 2, 3)
  x = np.arange(prod(shape)).reshape(shape)
  ans = f(x)
  expected = 3 * onp.ones(shape[:2])
  self.assertAllClose(ans, expected, check_dtypes=False)

  # Test that 'ans' was properly replicated across devices.
  expected_sharded = pmap(pmap(lambda x: x), devices=devices)(expected)
  self.assertEqual([b.device() for b in ans.device_buffers],
                   [b.device() for b in expected_sharded.device_buffers])
def foo(x):
  @partial(pmap, axis_name='i', devices=xla_bridge.devices())
  def bar(y):
    return lax.psum(y, 'i')
  return bar(x)
def can_train_parallel():
  return (xla_bridge.get_backend().platform == "tpu"
          and len(xla_bridge.devices()) > 1)
def test_jit_device(self):
  device = xb.devices()[-1]
  x = api.jit(lambda x: x, device=device)(3.)
  self.assertIsInstance(x, DeviceArray)
  self.assertEqual(x.device_buffer.device(), device)