def test_worker_id(self):
    """Check worker-id lookup for this worker, a peer, and an unknown name.

    Asserts that ``dist.get_worker_id()`` with no argument yields an id
    named ``worker<rank>`` for this process, that looking up the next
    rank's worker by name yields an id carrying that name, and that an
    unregistered name raises ``RuntimeError``.
    """
    n = self.rank + 1
    # Wrap around so the highest rank uses rank 0 as its peer.
    peer_rank = n % self.world_size

    self_worker_id = dist.get_worker_id()
    peer_worker_id = dist.get_worker_id('worker{}'.format(peer_rank))

    self.assertEqual(self_worker_id.name, 'worker{}'.format(self.rank))
    self.assertEqual(peer_worker_id.name, 'worker{}'.format(peer_rank))

    # Only the raised error matters here, so the return value is not bound.
    with self.assertRaisesRegex(RuntimeError, "Unknown destination worker"):
        dist.get_worker_id("WorkerUnknown")
def test_add_with_id(self):
    """Run ``torch.add`` on the next-rank worker, addressed by worker id.

    Two ``size x size`` tensors of ones are sent (``size = rank + 1``) and
    the returned value is compared against a tensor of twos.
    """
    size = self.rank + 1
    # Wrap around so the highest rank targets rank 0.
    target_rank = size % self.world_size
    target_id = dist.get_worker_id('worker{}'.format(target_rank))

    operands = (torch.ones(size, size), torch.ones(size, size))
    result = dist.rpc(target_id, torch.add, args=operands)

    expected = torch.ones(size, size) * 2
    self.assertEqual(result, expected)
def test_py_multi_async_call(self):
    """Issue two async RPCs to the same worker and verify both results.

    Both calls are launched before either future is waited on. Each
    remote result is compared with the same call made locally.
    """
    base = self.rank + 1
    # Wrap around so the highest rank targets rank 0.
    target_rank = base % self.world_size
    target_id = dist.get_worker_id("worker{}".format(target_rank))

    static_args = (base + 10,)
    min_args = (base, base + 1, base + 2)

    static_future = dist.rpc_async(
        target_id, my_class.my_static_method, args=static_args
    )
    min_future = dist.rpc_async(target_id, min, args=min_args)

    # Wait on each future before computing its local reference value.
    remote_static = static_future.wait()
    self.assertEqual(remote_static, my_class.my_static_method(base + 10))

    remote_min = min_future.wait()
    self.assertEqual(remote_min, min(base, base + 1, base + 2))
def test_self_add(self):
    """Verify that an RPC aimed at the calling worker is rejected.

    The call is attempted twice, first with this worker's id object and
    then with its ``worker<rank>`` name; each attempt must raise
    ``RuntimeError`` with the expected message.
    """
    expected_error = "does not support making RPC calls to self"

    own_id = dist.get_worker_id()
    own_name = 'worker{}'.format(self.rank)

    # Same check for both ways of naming the destination: id, then name.
    for destination in (own_id, own_name):
        with self.assertRaisesRegex(RuntimeError, expected_error):
            dist.rpc(destination, torch.add, args=(torch.ones(2, 2), 1))