def test_simple_isendrecv(self):
    """Round-trip a large buffer through non-blocking send/recv and check gradients.

    Each rank sends to its right neighbour and receives from its left one.
    The loss scales the received data by ``comm.rank``, so after the ring
    exchange the gradient landing back on ``buf`` must equal the *receiving*
    neighbour's rank, i.e. ``(comm.rank + 1) % comm.size``.
    """
    buf = torch.rand(10000000, dtype=torch.double).requires_grad_()
    right = (comm.rank + 1) % comm.size
    left = (comm.rank + comm.size - 1) % comm.size

    # Kick off the non-blocking send to the right neighbour (tag 0).
    send_req = comm.Isend(buf, right, 0)

    # JoinDummies threads the send handle into the autograd graph so the
    # receive is ordered after the Isend during backward as well.
    recv_buf = mpi4torch.JoinDummies(torch.empty_like(buf), [send_req.dummy])
    received = comm.Recv(recv_buf, left, 0)

    # Complete the send; JoinDummiesHandle makes the Wait depend on the Recv.
    wait_out = comm.Wait(mpi4torch.JoinDummiesHandle(send_req, [received]))

    # Scale by our own rank; the gradient flows back to the sender's buffer.
    scaled = mpi4torch.JoinDummies(received, [wait_out]) * comm.rank
    scaled.sum().backward()

    expected = ((comm.rank + 1) % comm.size) * torch.ones_like(buf)
    self.assertTrue((buf.grad == expected).all())
def test_simple_allreduce(self):
    """Allreduce with MPI_SUM: gradient of the sum is ``comm.size`` per entry.

    Two extra leaf tensors are joined in as dummies only; they must receive
    zero gradient, proving JoinDummies does not leak gradient into its
    dummy arguments.
    """
    primary = torch.rand(10, dtype=torch.double).requires_grad_()
    dummy_a = torch.rand(10, dtype=torch.double).requires_grad_()
    dummy_b = torch.rand(10, dtype=torch.double).requires_grad_()

    reduced = comm.Allreduce(primary, mpi4torch.MPI_SUM)

    # Attach the two extra tensors purely as ordering dummies.
    joined = mpi4torch.JoinDummies(reduced, [dummy_a, dummy_b])
    joined.sum().backward()

    zeros = torch.zeros(10, dtype=torch.double)
    self.assertTrue((dummy_a.grad == zeros).all())
    self.assertTrue((dummy_b.grad == zeros).all())
    # Every rank contributes, so each entry's gradient is the world size.
    self.assertTrue((primary.grad == comm.size * torch.ones(10, dtype=torch.double)).all())
"""Minimal mpi4torch example: a differentiable ring exchange.

Every rank sends ``1.0 + rank`` to its right neighbour, receives from its
left neighbour, sums the two values, and backpropagates through the
exchange.  Run under MPI, e.g. ``mpirun -n 2 python this_script.py``.
"""
import torch

import mpi4torch

comm = mpi4torch.COMM_WORLD

payload = torch.tensor([1.0 + comm.rank]).requires_grad_()

right_neighbour = (comm.rank + 1) % comm.size
left_neighbour = (comm.rank - 1 + comm.size) % comm.size

# Non-blocking send to the right neighbour (tag 0).
send_handle = comm.Isend(payload, right_neighbour, 0)

# JoinDummies orders the receive after the Isend inside the autograd graph.
recv_target = mpi4torch.JoinDummies(torch.empty_like(payload), [send_handle.dummy])
incoming = comm.Recv(recv_target, left_neighbour, 0)

# Complete the send; the handle is joined with the received tensor so the
# Wait is sequenced after the Recv during backward too.
wait_result = comm.Wait(mpi4torch.JoinDummiesHandle(send_handle, [incoming]))

# Combine local and received values; wait_result enters only as a dummy.
res = mpi4torch.JoinDummies(payload + incoming, [wait_result])
print(res)

# res has a single element, so backward() needs no explicit gradient.
res.backward()
print(payload.grad)