Beispiel #1
0
    def test_worker_id(self):
        """Check that worker ids resolve to the expected worker names.

        Calling ``dist.get_worker_id()`` with no argument must yield this
        worker's own id (named ``worker<rank>``). Calling it with a peer's
        name must yield that peer's id. An unregistered name must raise
        ``RuntimeError`` mentioning "Unknown destination worker".
        """
        n = self.rank + 1
        # Query the next rank, wrapping around at world_size.
        peer_rank = n % self.world_size
        self_worker_id = dist.get_worker_id()
        peer_worker_id = dist.get_worker_id('worker{}'.format(peer_rank))

        self.assertEqual(self_worker_id.name, 'worker{}'.format(self.rank))
        self.assertEqual(peer_worker_id.name, 'worker{}'.format(peer_rank))

        # The lookup is expected to raise, so its result is deliberately not
        # bound to a name.
        with self.assertRaisesRegex(RuntimeError, "Unknown destination worker"):
            dist.get_worker_id("WorkerUnknown")
Beispiel #2
0
    def test_add_with_id(self):
        """Run ``torch.add`` on the next worker, addressed by its worker id.

        The destination is the id returned by ``dist.get_worker_id`` for
        ``worker<(rank + 1) % world_size>``. Adding two ``n x n`` tensors of
        ones remotely must return a tensor of twos.
        """
        n = self.rank + 1
        # Target the next rank, wrapping around at world_size.
        dst_rank = n % self.world_size
        dst_worker_id = dist.get_worker_id('worker{}'.format(dst_rank))

        ret = dist.rpc(dst_worker_id, torch.add,
                       args=(torch.ones(n, n), torch.ones(n, n)))
        self.assertEqual(ret, torch.ones(n, n) * 2)
Beispiel #3
0
 def test_py_multi_async_call(self):
     """Issue two async RPCs to the next worker and verify both results.

     The first future runs ``my_class.my_static_method`` remotely and the
     second runs the builtin ``min``. Each remote result is compared with
     the same call evaluated locally.
     """
     caller_n = self.rank + 1
     target = dist.get_worker_id("worker{}".format(caller_n % self.world_size))
     static_arg = caller_n + 10
     min_args = (caller_n, caller_n + 1, caller_n + 2)

     # Launch both calls before waiting on either one.
     static_future = dist.rpc_async(target, my_class.my_static_method, args=(static_arg,))
     min_future = dist.rpc_async(target, min, args=min_args)

     self.assertEqual(static_future.wait(), my_class.my_static_method(static_arg))
     self.assertEqual(min_future.wait(), min(*min_args))
Beispiel #4
0
    def test_self_add(self):
        """Verify that an RPC addressed to the calling worker is rejected.

        Both ways of naming the local worker are checked: the id returned by
        ``dist.get_worker_id()`` and the plain ``worker<rank>`` string. Each
        must raise ``RuntimeError`` stating that self-RPC is unsupported.
        """
        # The id lookup runs first, then the name is formatted; this matches
        # the evaluation order of separate assignments.
        local_destinations = (
            dist.get_worker_id(),
            'worker{}'.format(self.rank),
        )
        for local_dst in local_destinations:
            with self.assertRaisesRegex(
                    RuntimeError, "does not support making RPC calls to self"):
                dist.rpc(local_dst, torch.add, args=(torch.ones(2, 2), 1))