コード例 #1
0
 def test_simple_isendrecv(self):
     """Ring exchange via Isend/Recv/Wait, then check the backward pass.

     Every rank sends a random double tensor to its right neighbour and
     receives one from its left neighbour. The received tensor is scaled by
     the local rank before summing. The test expects the gradient of the
     sent tensor to equal the right neighbour's rank in every element.
     """
     right = (comm.rank + 1) % comm.size
     left = (comm.rank + comm.size - 1) % comm.size

     # Very large payload; presumably chosen so the send cannot finish
     # eagerly before the matching receive is posted -- TODO confirm intent.
     payload = torch.rand(10000000, dtype=torch.double).requires_grad_()
     send_req = comm.Isend(payload, right, 0)

     # The receive buffer is joined with the send request's dummy, presumably
     # so the Recv is sequenced after the Isend (see mpi4torch docs).
     recv_buf = mpi4torch.JoinDummies(torch.empty_like(payload),
                                      [send_req.dummy])
     received = comm.Recv(recv_buf, left, 0)

     # Wait on the send only once the receive result exists.
     waited = comm.Wait(mpi4torch.JoinDummiesHandle(send_req, [received]))
     scaled = mpi4torch.JoinDummies(received, [waited]) * comm.rank
     scaled.sum().backward()

     expected = right * torch.ones_like(payload)
     self.assertTrue((payload.grad == expected).all())
コード例 #2
0
 def test_simple_allreduce(self):
     """Backward pass of a SUM Allreduce with two tensors joined as dummies.

     Only the first tensor feeds the Allreduce. The other two are attached
     via ``JoinDummies``, and the test expects them to get all-zero
     gradients. The reduced tensor's input is expected to get a gradient of
     ``comm.size`` in every element.
     """
     numel = 10
     # Creation order matches the RNG draw order of the original test.
     source = torch.rand(numel, dtype=torch.double).requires_grad_()
     dummy_a = torch.rand(numel, dtype=torch.double).requires_grad_()
     dummy_b = torch.rand(numel, dtype=torch.double).requires_grad_()

     reduced = comm.Allreduce(source, mpi4torch.MPI_SUM)
     joined = mpi4torch.JoinDummies(reduced, [dummy_a, dummy_b])
     joined.sum().backward()

     zeros = torch.zeros(numel, dtype=torch.double)
     for dummy in (dummy_a, dummy_b):
         self.assertTrue((dummy.grad == zeros).all())

     expected = comm.size * torch.ones(numel, dtype=torch.double)
     self.assertTrue((source.grad == expected).all())
コード例 #3
0
import torch
import mpi4torch

comm = mpi4torch.COMM_WORLD

# Ring topology: each rank sends to the right and receives from the left.
right = (comm.rank + 1) % comm.size
left = (comm.rank - 1 + comm.size) % comm.size

# One-element tensor holding (rank + 1), with gradient tracking enabled.
a = torch.tensor([1.0 + comm.rank]).requires_grad_()

# Non-blocking send; the receive buffer is joined with the handle's dummy,
# presumably so the Recv is sequenced after the Isend (see mpi4torch docs).
handle = comm.Isend(a, right, 0)
recvbuffer = mpi4torch.JoinDummies(torch.empty_like(a), [handle.dummy])
b = comm.Recv(recvbuffer, left, 0)

# Complete the send only after ``b`` has been produced by the Recv.
wait_ret = comm.Wait(mpi4torch.JoinDummiesHandle(handle, [b]))

# Tie the Wait result into the output so it is part of the graph.
res = mpi4torch.JoinDummies(a + b, [wait_ret])
print(res)

res.backward()
print(a.grad)