def worker(val, shape):
    """Send ``val`` from rank 0 and verify it on the receiving rank.

    Rank 0 builds a tensor from ``val`` on ``"gpu0"``, sends it to rank 1
    and then calls ``sync()``. Every other rank receives from rank 0 with
    the given ``shape`` and ``np.float32`` dtype, asserts the result lives
    on ``"gpu1"`` and compares its contents against ``val``.
    """
    if dist.get_rank() == 0:
        # Sender side.
        payload = tensor(val, device="gpu0")
        remote_send(payload, 1)
        sync()
        return

    # Receiver side.
    received = remote_recv(0, shape, np.float32)
    assert received.device == "gpu1"
    np.testing.assert_almost_equal(val, received.numpy())
def worker(val, shape):
    """Send ``val`` from rank 0 and verify it on the receiving rank.

    Rank 0 builds a tensor from ``val`` on ``"xpu0"``, sends it to rank 1
    and then calls ``sync()``. Every other rank receives from rank 0,
    asserts the result is on the device returned by
    ``get_default_device()`` and compares its contents against ``val``.

    ``shape`` is accepted but not used by this body.
    """
    current_rank = dist.get_rank()
    if current_rank != 0:
        # Receiver side.
        incoming = remote_recv(0)
        assert incoming.device == get_default_device()
        np.testing.assert_almost_equal(val, incoming.numpy())
    else:
        # Sender side.
        outgoing = tensor(val, device="xpu0")
        remote_send(outgoing, 1)
        sync()
def worker():
    """Run a two-rank round trip under ``Grad`` and check the gradient.

    Rank 1 receives a tensor from rank 0 and sends it straight back, then
    invokes the grad object with empty lists. Rank 0 sends ``x`` (built
    from ``x_np``), receives a tensor from rank 1, squares it, runs the
    backward call with an all-ones seed and asserts that ``x.grad`` equals
    ``2 * x``. Ranks other than 0 and 1 do nothing.
    """
    rank = dist.get_rank()

    if rank == 1:
        with Grad() as tape:
            echoed = remote_recv(0)
            remote_send(echoed, 0)
            tape([], [])
        return

    if rank != 0:
        return

    with Grad() as tape:
        x = as_tensor(x_np)
        tape.wrt(x, callback=save_to(x))
        # need a placeholder to trace operator
        remote_send(x, 1)
        returned = remote_recv(1)
        squared = returned * returned
        seed = as_tensor(np.ones_like(x_np))
        tape([squared], [seed])
    # NOTE(review): the check runs after the Grad context has exited;
    # confirm this matches the intended placement.
    actual = x.grad.numpy()
    expected = x.numpy() * 2
    np.testing.assert_almost_equal(actual, expected)
def worker(rank):
    """Initialise the process group for ``rank`` and run one send/recv check.

    Returns immediately when ``mge.get_device_count("gpu")`` reports fewer
    devices than ``world_size``. Otherwise rank 0 sends ``val`` from
    ``"gpu0"`` to rank 1 and asserts the first element of the send result
    is 0; every other rank receives from rank 0, asserts the result is on
    ``"gpu1"`` and compares it against ``val``.
    """
    enough_gpus = mge.get_device_count("gpu") >= world_size
    if not enough_gpus:
        return

    # Both roles join the group with identical arguments before diverging.
    dist.init_process_group("localhost", port, world_size, rank, rank)

    if rank == 0:
        # Sender side.
        source = Tensor(val, device="gpu0")
        send_result = remote_send(source, 1)
        assert send_result.numpy()[0] == 0
        return

    # Receiver side.
    received = remote_recv(0, val.shape, val.dtype)
    assert received.device == "gpu1"
    np.testing.assert_almost_equal(val, received.numpy())
def worker1():
    """Rank-1 side: receive a tensor from rank 0 and send it back.

    Joins the process group as rank 1 on device 1, makes ``"gpu1"`` the
    default device, receives a tensor shaped and typed like ``x_np`` from
    rank 0, sends it back to rank 0, invokes the grad object with empty
    lists and finally synchronises.
    """
    rank = 1
    device_name = "gpu1"

    dist.init_process_group("localhost", port, world_size, rank, rank)
    mge.device.set_default_device(device_name)

    tape = Grad()
    incoming = remote_recv(0, x_np.shape, x_np.dtype, device_name)
    outgoing = remote_send(incoming, 0)
    tape([], [])

    # sync because grad has a send operator
    sync()
    outgoing.device._cn._sync_all()
def worker0():
    """Rank-0 side: send ``x``, square what comes back, check the gradient.

    Joins the process group as rank 0 on device 0, makes ``"gpu0"`` the
    default device, registers ``x`` (built from ``x_np``) with the grad
    object, sends it to rank 1, receives a tensor shaped and typed like
    ``x_np`` from rank 1, squares it, runs the backward call with an
    all-ones seed and asserts that ``x.grad`` equals ``2 * x``.
    """
    rank = 0
    device_name = "gpu0"

    dist.init_process_group("localhost", port, world_size, rank, rank)
    mge.device.set_default_device(device_name)

    tape = Grad()
    x = as_tensor(x_np)
    tape.wrt(x, callback=save_to(x))

    # need a placeholder to trace operator
    send_placeholder = remote_send(x, 1)
    returned = remote_recv(1, x_np.shape, x_np.dtype, device_name)
    squared = returned * returned

    seed = as_tensor(np.ones_like(x_np))
    tape([squared], [seed])

    actual = x.grad.numpy()
    expected = x.numpy() * 2
    np.testing.assert_almost_equal(actual, expected)