def testPmapConstant(self):
    """Check pmap over functions whose output is, or contains, a constant.

    Maps ``lambda x: 3`` and asserts that the counter yielded by
    ``jtu.count_jit_and_pmap_compiles()`` reads 0, then maps
    ``lambda x: (x, 3)`` and asserts that the counter reads 1. In both
    cases the constant output is compared against 3 repeated once per
    device.
    """
    n_devices = xla_bridge.device_count()
    # Reference value shared by both halves of the test.
    want = onp.repeat(3, n_devices)

    # Case 1: the mapped function ignores its argument entirely.
    const_only = pmap(lambda x: 3)
    inputs = np.arange(n_devices)
    with jtu.count_jit_and_pmap_compiles() as compiles:
        got = const_only(inputs)
    self.assertEqual(compiles[0], 0)
    self.assertAllClose(got, want, check_dtypes=False)

    # Case 2: the constant is returned alongside the mapped argument.
    passthrough_and_const = pmap(lambda x: (x, 3))
    with jtu.count_jit_and_pmap_compiles() as compiles:
        outputs = passthrough_and_const(inputs)
    got = outputs[1]
    self.assertEqual(compiles[0], 1)
    self.assertAllClose(got, want, check_dtypes=False)
def testCompilationCache(self):
    """Check that calling a sharded_jit function twice compiles it once.

    The function ``x + 1`` is wrapped with ``sharded_jit`` using
    ``P(2)`` for both input and output partitions, invoked twice on the
    same float32 argument, and the compile counter is asserted to be 1.
    """
    sharded_f = sharded_jit(lambda x: x + 1, in_parts=P(2), out_parts=P(2))

    dims = (2, )
    flat = np.arange(prod(dims), dtype=np.float32)
    arg = flat.reshape(dims)

    with jtu.count_jit_and_pmap_compiles() as compiles:
        # Two identical invocations; the assertion below expects a single
        # compile to be recorded for both.
        for _ in range(2):
            sharded_f(arg)
    self.assertEqual(compiles[0], 1)
def testPmapConstantDevices(self):
    """Check a constant-output pmap run on an explicit, shuffled device list.

    Skipped when only one device is available. Uses every device except
    the last one, in shuffled order, and asserts that no compile is
    counted, that the result is 3 repeated once per chosen device, and
    that the result buffers report the chosen devices in the same order.

    Raises:
        SkipTest: if ``xla_bridge.device_count()`` is 1.
    """
    if xla_bridge.device_count() == 1:
        raise SkipTest("this test requires multiple devices")

    # Drop the last device, then randomise the order in place.
    chosen = xla_bridge.devices()[:-1]
    shuffle(chosen)
    n_chosen = len(chosen)

    const_fn = pmap(lambda x: 3, devices=chosen)
    inputs = np.arange(n_chosen)
    with jtu.count_jit_and_pmap_compiles() as compiles:
        got = const_fn(inputs)
    self.assertEqual(compiles[0], 0)

    want = onp.repeat(3, n_chosen)
    self.assertAllClose(got, want, check_dtypes=False)

    # Test that the result was properly replicated across devices.
    placement = [buf.device() for buf in got.device_buffers]
    self.assertEqual(placement, chosen)
def testNestedPmapConstantDevices(self):
    """Check a nested constant-output pmap on an explicit device list.

    Currently disabled: the method raises ``SkipTest`` unconditionally as
    its first statement, so nothing below that line executes. The rest of
    the body is kept so the test can be enabled later.

    Raises:
        SkipTest: always.
    """
    raise SkipTest("Nested pmaps with devices not yet implemented")

    # --- Unreachable until the skip above is removed. ---
    if xla_bridge.device_count() < 6:
        raise SkipTest("this test requires >= 6 devices")

    def devices_of(arr):
        # Devices reported by each buffer of ``arr``, in buffer order.
        return [buf.device() for buf in arr.device_buffers]

    # Drop the last two devices, then randomise the order in place.
    chosen = xla_bridge.devices()[:-2]
    shuffle(chosen)

    nested_const = pmap(pmap(lambda x: 3), devices=chosen)
    dims = (2, len(chosen) // 2, 3)
    inputs = np.arange(prod(dims)).reshape(dims)
    with jtu.count_jit_and_pmap_compiles() as compiles:
        got = nested_const(inputs)
    self.assertEqual(compiles[0], 0)

    outer_dims = dims[:2]
    want = 3 * onp.ones(outer_dims)
    self.assertAllClose(got, want, check_dtypes=False)

    # Test that the result was properly replicated across devices: its
    # buffers must sit where an identity nested pmap puts the reference.
    nested_identity = pmap(pmap(lambda x: x), devices=chosen)
    want_sharded = nested_identity(want)
    self.assertEqual(devices_of(got), devices_of(want_sharded))
def testNestedPmapConstant(self):
    """Check nested pmap over functions whose output includes a constant.

    Skipped when only one device is available. The input has shape
    ``(2, device_count // 2, 3)``. For ``lambda x: 3`` the test asserts
    that no compile is counted, that the result equals an array of 3s
    over the two leading dimensions, and that its buffers report the same
    devices as an identity nested pmap applied to the reference array.
    For ``lambda x: (x, 3)`` it asserts the same values and that the
    constant's buffers report the same devices as the returned input.

    Raises:
        SkipTest: if ``xla_bridge.device_count()`` is 1.
    """
    if xla_bridge.device_count() == 1:
        raise SkipTest("this test requires multiple devices")

    def devices_of(arr):
        # Devices reported by each buffer of ``arr``, in buffer order.
        return [buf.device() for buf in arr.device_buffers]

    dims = (2, xla_bridge.device_count() // 2, 3)
    outer_dims = dims[:2]

    # Case 1: the mapped function ignores its argument entirely.
    nested_const = pmap(pmap(lambda x: 3))
    inputs = np.arange(prod(dims)).reshape(dims)
    with jtu.count_jit_and_pmap_compiles() as compiles:
        got = nested_const(inputs)
    self.assertEqual(compiles[0], 0)

    want = 3 * onp.ones(outer_dims)
    self.assertAllClose(got, want, check_dtypes=False)

    # Test that the result was properly replicated across devices.
    nested_identity = pmap(pmap(lambda x: x))
    want_sharded = nested_identity(want)
    self.assertEqual(devices_of(got), devices_of(want_sharded))

    # Case 2: the constant is returned alongside the mapped argument and
    # must be placed on the same devices as that argument.
    nested_pair = pmap(pmap(lambda x: (x, 3)))
    echoed, got = nested_pair(inputs)
    self.assertAllClose(got, want, check_dtypes=False)
    self.assertEqual(devices_of(got), devices_of(echoed))