def test_device_put_across_devices(self):
  """device_put places a buffer on a requested device and can re-place it.

  Puts an array on the first local device, then puts the result on the
  second local device, checking that each buffer reports the device it was
  placed on. Finally calls device_put with no explicit device on both
  results to make sure that path does not crash.
  """
  # Gate on the *local* device list rather than xb.device_count():
  # the test unpacks two devices from xb.local_devices(), and a global
  # device count can exceed what this host can address (multi-host runs),
  # which would turn a skip into an unpacking error.
  local_devices = xb.local_devices()
  if len(local_devices) < 2:
    raise unittest.SkipTest("this test requires multiple devices")
  d1, d2 = local_devices[:2]
  x = api.device_put(onp.array([1, 2, 3]), device=d1)
  self.assertEqual(x.device_buffer.device(), d1)
  y = api.device_put(x, device=d2)
  self.assertEqual(y.device_buffer.device(), d2)
  # Make sure these don't crash
  api.device_put(x)
  api.device_put(y)
def test_local_devices(self):
  """xb.local_devices lists devices and rejects unknown hosts/backends."""
  # With default arguments at least one device must be reported.
  self.assertNotEmpty(xb.local_devices())
  # Each entry: (expected exception type, message regex, offending call).
  invalid_requests = [
      (ValueError, "Unknown host_id 100",
       lambda: xb.local_devices(100)),
      (RuntimeError, "Unknown backend foo",
       lambda: xb.local_devices(backend="foo")),
  ]
  for expected_exc, message_re, bad_call in invalid_requests:
    with self.assertRaisesRegex(expected_exc, message_re):
      bad_call()