def worker():
    """Run a pipeline-parallel training loop on every rank of the process group.

    Rank ``r`` owns a ``Linear(2r + 2, 2r + 4)`` layer, so the output width of
    rank ``r`` equals the input width of rank ``r + 1``. Rank 0 feeds its own
    random tensor into the layer. Every other rank discards its local input
    and uses the activation received from rank ``r - 1`` instead. All ranks
    except the last forward their output to rank ``r + 1``. The last rank
    reduces its output with ``mean()`` and passes it to ``gm.backward``.

    The same step runs in three ways: untraced, ``trace(symbolic=False)`` and
    ``trace(symbolic=True)``. Each variant runs three iterations and is
    followed by a ``sync()``.
    """
    rank = dist.get_rank()
    world = dist.get_world_size()
    is_first = rank == 0
    is_last = rank == world - 1

    in_features = rank * 2 + 2
    # Equals rank * 2 + 4, which is exactly the next rank's in_features.
    out_features = in_features + 2

    # Draw the input before building the layer, matching the original order of
    # random-number consumption.
    data = mge.tensor(np.random.randn(1, in_features), dtype=np.float32)
    layer = M.Linear(in_features, out_features)
    gm = GradManager().attach(layer.parameters())
    opt = optim.SGD(layer.parameters(), 1e-3, momentum=0.9)

    def step(inp):
        with gm:
            if not is_first:
                # Non-first stages ignore ``inp`` and consume the previous
                # stage's output instead.
                inp = dist.functional.remote_recv(
                    rank - 1, shape=(1, in_features), dtype=np.float32
                )
            out = layer(inp)
            if is_last:
                loss = out.mean()
                gm.backward(loss)
            else:
                dist.functional.remote_send(out, dest_rank=rank + 1)
                # No explicit target on intermediate stages. NOTE(review):
                # presumably the gradient arrives through the
                # remote_send/remote_recv pair; confirm.
                gm.backward()
            opt.step().clear_grad()

    variants = (
        step,
        trace(symbolic=False)(step),
        trace(symbolic=True)(step),
    )
    for run in variants:
        for _ in range(3):
            run(data)
        sync()
def worker1():
    """Act as the second peer in a two-process remote send/recv gradient test.

    Joins the process group and sets ``gpu1`` as the default device. Receives
    a tensor shaped and typed like ``x_np`` from rank 0, sends it straight
    back to rank 0, and then calls the ``Grad`` object with empty outputs
    and empty output gradients.

    NOTE(review): ``port``, ``world_size`` and ``x_np`` are free names that
    must come from the enclosing or module scope; they are not defined here.
    The ``1, 1`` passed to ``init_process_group`` are presumably this
    process's rank and device index; confirm against its signature.
    """
    device = "gpu1"
    dist.init_process_group("localhost", port, world_size, 1, 1)
    mge.device.set_default_device(device)
    grad_fn = Grad()

    received = remote_recv(0, x_np.shape, x_np.dtype, device)
    echoed = remote_send(received, 0)

    grad_fn([], [])

    # sync because grad has a send operator
    sync()
    echoed.device._cn._sync_all()
def pytest_runtest_teardown():
    """Pytest teardown hook that calls ``sync()`` after each test.

    Pytest invokes this once per test item during the teardown phase, when it
    is defined in a conftest or plugin module.

    NOTE(review): this presumably makes pending device work from one test
    finish before the next test starts; confirm against ``sync``'s definition.
    """
    sync()