Exemplo n.º 1
0
 def test_sync_rpc(self):
     """Check ``torch.add`` results returned by ``dist.rpc`` from the next worker.

     For 20 iterations, two remote additions (tensor + tensor, then
     tensor + scalar) are requested from the worker whose rank follows
     this one, with ``dist.sync_rpc()`` called before, between and after
     them. The results are compared against locally built tensors.
     """
     # NOTE(review): dist.sync_rpc() presumably acts as a barrier across
     # workers -- confirm against the RPC backend in use.
     # The target is the next rank, wrapping around at world_size; the name
     # does not change across iterations, so it is built once.
     peer = 'worker%d' % ((self.rank + 1) % self.world_size)
     for step in range(20):
         dist.sync_rpc()
         # Side length of the square tensors; depends on iteration and rank.
         side = step + self.rank + 1
         tensor_sum = dist.rpc(
             peer,
             torch.add,
             args=(torch.ones(side, side), torch.ones(side, side)),
         )
         dist.sync_rpc()
         scalar_sum = dist.rpc(peer, torch.add, args=(torch.ones(side, side), 2))
         dist.sync_rpc()
         # ones + ones == 2 everywhere; ones + 2 == 3 everywhere.
         self.assertEqual(tensor_sum, torch.ones(side, side) * 2)
         self.assertEqual(scalar_sum, torch.ones(side, side) * 3)
Exemplo n.º 2
0
 def test_sync_rpc(self):
     """Check ``torch.add`` results returned by ``dist.rpc_sync`` from the next worker.

     For 20 iterations, two remote additions (tensor + tensor, then
     tensor + scalar) are requested from the worker whose rank follows
     this one, with ``dist.sync_rpc()`` called before, between and after
     them. The results are compared against locally built tensors.
     """
     # NOTE(review): dist.sync_rpc() presumably acts as a barrier across
     # workers -- confirm against the RPC backend in use.
     # The destination is the next rank modulo world_size; its name is
     # loop-invariant, so it is formatted once up front.
     target = "worker{}".format((self.rank + 1) % self.world_size)
     for iteration in range(20):
         dist.sync_rpc()
         # Side length of the square tensors; depends on iteration and rank.
         dim = iteration + self.rank + 1
         ones_plus_ones = dist.rpc_sync(
             target, torch.add, args=(torch.ones(dim, dim), torch.ones(dim, dim))
         )
         dist.sync_rpc()
         ones_plus_two = dist.rpc_sync(
             target,
             torch.add,
             args=(torch.ones(dim, dim), 2),
         )
         dist.sync_rpc()
         # ones + ones == 2 everywhere; ones + 2 == 3 everywhere.
         self.assertEqual(ones_plus_ones, torch.ones(dim, dim) * 2)
         self.assertEqual(ones_plus_two, torch.ones(dim, dim) * 3)