def run_init(rank, n_workers):
    """Worker body: join a process group and check a broadcast works.

    Selects CUDA device number ``rank``, then calls
    ``init_process_group(n_workers, rank)`` to obtain a communicator.
    Rank 0 broadcasts a one-element array of ones. Every rank then asserts
    that its buffer holds ones, which shows the communicator returned is
    usable.

    Args:
        rank: This worker's rank; also used as the CUDA device index.
        n_workers: Total number of workers in the process group.
    """
    cuda.Device(rank).use()
    comm = init_process_group(n_workers, rank)

    # Only the root starts with the payload. The other ranks start with
    # zeros, so the assertion below can only pass if the broadcast
    # actually delivered data.
    buf = cupy.zeros(1) + 1 if rank == 0 else cupy.zeros(1)
    comm.broadcast(buf, 0)
    testing.assert_allclose(buf, cupy.ones(1))
def test_invalid_n_devices(self):
    """A device count of zero or a negative value must raise ValueError."""
    for n_devices in (0, -1):
        with pytest.raises(ValueError):
            init_process_group(n_devices, 0)
def test_invalid_rank(self):
    """Out-of-range ranks must raise ValueError.

    With 2 devices, this covers a negative rank, the first out-of-range
    rank (equal to the device count), and a rank well past the end.
    """
    n_devices = 2
    # rank == n_devices is the off-by-one boundary. Valid ranks are
    # presumably 0 .. n_devices - 1; an implementation that checks
    # `rank > n_devices` would accept it.
    for bad_rank in (-1, n_devices, n_devices + 1):
        with pytest.raises(ValueError):
            init_process_group(n_devices, bad_rank)
def test_invalid_backend(self):
    """Requesting the 'mpi' backend must raise ValueError."""
    unsupported_backend = 'mpi'
    with pytest.raises(ValueError):
        init_process_group(1, 0, backend=unsupported_backend)