def test_nested_context(self):
    with dist_autograd.context() as context_id:
        # Nested contexts not supported.
        with self.assertRaisesRegex(
                RuntimeError,
                "Already have an autograd context id for this thread"):
            with dist_autograd.context() as context_id:
                pass
def run_master():
    # put the four model parts on worker1 through worker4 respectively
    model = DistResNet(["worker1", "worker2", "worker3", "worker4"])
    loss_fn = nn.MSELoss()
    opt = DistributedOptimizer(
        optim.SGD,
        model.parameter_rrefs(),
        lr=0.05,
    )

    one_hot_indices = torch.LongTensor(batch_size) \
                           .random_(0, num_classes) \
                           .view(batch_size, 1)

    for i in range(num_batches):
        print(f"Processing batch {i}")
        # generate random inputs and labels
        inputs = torch.randn(batch_size, 3, image_w, image_h)
        labels = torch.zeros(batch_size, num_classes) \
                      .scatter_(1, one_hot_indices, 1)

        # The distributed autograd context is the dedicated scope for the
        # distributed backward pass to store gradients, which can later be
        # retrieved using the context_id by the distributed optimizer.
        with dist_autograd.context() as context_id:
            outputs = model(inputs)
            dist_autograd.backward(context_id, [loss_fn(outputs, labels)])
            opt.step(context_id)
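# A minimal launcher sketch for the run_master() example above. This is an
# assumption, not part of the original snippet: it only shows how each process
# might initialize RPC so the master can place model shards on
# "worker1" ... "worker4" before driving the training loop.
import os

import torch.multiprocessing as mp
import torch.distributed.rpc as rpc


def run_worker(rank, world_size):
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    if rank == 0:
        # rank 0 drives the training loop and owns the autograd contexts
        rpc.init_rpc("master", rank=rank, world_size=world_size)
        run_master()
    else:
        # workers only host module shards and wait for RPCs from the master
        rpc.init_rpc(f"worker{rank}", rank=rank, world_size=world_size)
    rpc.shutdown()  # block until all outstanding RPC work is done


if __name__ == "__main__":
    world_size = 5  # 1 master + 4 workers, matching the DistResNet example
    mp.spawn(run_worker, args=(world_size,), nprocs=world_size, join=True)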
def test_worker_ids_recorded(self):
    dst_ranks = {rank for rank in range(self.world_size) if rank != self.rank}
    with dist_autograd.context() as context_id:
        # if no tensors require grad, we do not add the send functions, so
        # no worker ids should be recorded.
        t1 = torch.ones(3, 3, requires_grad=False)
        t2 = torch.zeros(3, 3, requires_grad=False)
        for dst_rank in dst_ranks:
            ret = rpc.rpc_sync("worker{}".format(dst_rank), torch.add,
                               args=(t1, t2))
            rpc.rpc_sync(
                "worker{}".format(dst_rank), _set_rpc_done, args=(context_id, 1)
            )

        # no worker ids should be recorded.
        ctx = dist_autograd._current_context()
        worker_ids = ctx._known_worker_ids()
        self.assertEqual(len(worker_ids), 0)

        # worker_ids should be recorded when tensors do require grad
        t1.requires_grad = True
        t2.requires_grad = True
        for dst_rank in dst_ranks:
            ret = rpc.rpc_sync("worker{}".format(dst_rank), torch.add,
                               args=(t1, t2))
            rpc.rpc_sync(
                "worker{}".format(dst_rank), _set_rpc_done, args=(context_id, 1)
            )

        # all worker_ids in dst_ranks should be recorded.
        worker_ids = ctx._known_worker_ids()
        self.assertEqual(len(worker_ids), len(dst_ranks))
        self.assertEqual(set(worker_ids), dst_ranks)
def test_rpc_complex_args(self):
    with dist_autograd.context() as context_id:
        num_tensors = 10
        tensors = []
        for i in range(num_tensors):
            tensors.append(torch.ones(3, 3, requires_grad=(i % 2 == 0)))
        ret = rpc.rpc_sync(
            "worker{}".format(self._next_rank()), torch.stack, args=(tensors,)
        )
        self.assertEqual(torch.stack(tensors), ret)

        # Verify that the appropriate tensors have been attached to the
        # autograd graph.
        next_funcs = list(
            dist_autograd._current_context()._send_functions().values()
        )[0].next_functions
        idx = 0
        for i in range(num_tensors):
            if i % 2 == 0:
                self.assertEqual(
                    "torch::autograd::AccumulateGrad", next_funcs[i][0].name()
                )
                self.assertEqual(tensors[i], next_funcs[i][0].variable)
            else:
                self.assertIsNone(next_funcs[i][0])

        # Verify that the worker id has been recorded in the context
        ctx = dist_autograd._current_context()
        worker_ids = ctx._known_worker_ids()
        self.assertEqual(len(worker_ids), 1)
        dst_rank = (self.rank + 1) % self.world_size
        self.assertEqual(worker_ids[0], dst_rank)
def test_graph_for_py_nested_call_itself(self):
    dst_rank = (self.rank + 1) % self.world_size
    with dist_autograd.context() as context_id:
        t1 = torch.ones(3, 3, requires_grad=True)
        t2 = torch.zeros(3, 3, requires_grad=True)
        ret = rpc.rpc_sync(
            "worker{}".format(dst_rank),
            my_py_nested_call,
            args=(t1, t2,
                  (self.rank - 1 + self.world_size) % self.world_size,
                  self.world_size, 0))
        rpc.rpc_sync("worker{}".format((self.rank + 1) % self.world_size),
                     _set_rpc_done, args=(context_id, 1))

        # For self.rank, there are 2 graphs to verify.
        # One is for the current context id, where this rank sends the first
        # rpc call and executes the torch.add() operator.
        # The other is for the prev context id, where this rank makes the
        # nested call.
        ctx = dist_autograd._current_context()
        self.assertEqual(context_id, ctx._context_id())
        send_functions = ctx._send_functions()
        self.assertEqual(2, len(send_functions))
        recv_functions = ctx._recv_functions()
        self.assertEqual(2, len(recv_functions))
        self._verify_graph_for_first_rpc_call(
            list(send_functions.values())[0],
            list(recv_functions.values())[1],
            t1, t2, ret)
        self._verify_graph_for_rpc_call_exec(list(send_functions.values())[1])

        # Verify two pairs of send and recv functions for the nested call.
        self._check_rpc_done(1)
        ctx = dist_autograd._retrieve_context(ctx_ids[1])
        self._verify_graph_for_nested_rpc_call(ctx)
        # this barrier is needed so one worker does not clean up their
        # autograd context before another worker tries to access it.
        dist.barrier()
def train(model, optimizer, epoch, data, train_loader):
    model.train()

    pbar = tqdm(total=int(data.train_mask.sum()))
    pbar.set_description(f'Epoch {epoch:02d}')

    x = data.x
    y = data.y.squeeze()

    total_loss = total_correct = 0
    for batch_size, n_id, adjs in train_loader:
        # `adjs` holds a list of `(edge_index, e_id, size)` tuples.
        with dist_autograd.context() as context_id:
            adjs = [adj for adj in adjs]
            out = model(x[n_id], adjs)
            loss = F.nll_loss(out, y[n_id[:batch_size]])
            dist_autograd.backward(context_id, [loss])
            optimizer.step(context_id)

        total_loss += float(loss)
        total_correct += int(
            out.argmax(dim=-1).eq(y[n_id[:batch_size]]).sum())
        pbar.update(batch_size)

    pbar.close()

    loss = total_loss / len(train_loader)
    approx_acc = total_correct / int(data.train_mask.sum())

    return loss, approx_acc
def test_backward_invalid_args(self):
    with dist_autograd.context() as context_id:

        with self.assertRaisesRegex(TypeError,
                                    "incompatible function arguments"):
            dist_autograd.backward(None)

        with self.assertRaisesRegex(
                RuntimeError,
                "No tensors provided for gradient computation"):
            dist_autograd.backward([])

        with self.assertRaisesRegex(RuntimeError, "requires_grad not set on"):
            t = torch.rand(3, 3)
            dist_autograd.backward([t])

        with self.assertRaisesRegex(
                RuntimeError,
                "is not a scalar, all roots need to be scalar"):
            t = torch.rand(3, 3, requires_grad=True)
            dist_autograd.backward([t])

        with self.assertRaisesRegex(
                RuntimeError,
                "does not have a valid gradient function"):
            t = torch.rand(1, requires_grad=True)
            dist_autograd.backward([t])
def test_backward_autograd_engine_error(self):
    with dist_autograd.context() as context_id:
        t1 = torch.rand((3, 3), requires_grad=True)
        t2 = torch.rand((3, 3), requires_grad=True)
        # Perform some ops before error simulation.
        tmp = (t1 + t2) * (t1 + t2)
        t3 = SimulateBackwardError.apply(tmp)

        # Run multiple round trips across different nodes and verify the
        # original node receives an error thrown on a node deep in the chain.
        val = rpc.rpc_sync('worker{}'.format(self._next_rank()), torch.add,
                           args=(t2, t3))
        val = rpc.rpc_sync('worker{}'.format(self._next_rank()), torch.mul,
                           args=(val, t2))
        val = rpc.rpc_sync('worker{}'.format(self._next_rank()), torch.matmul,
                           args=(val, t2))
        val = rpc.rpc_sync('worker{}'.format(self._next_rank()), torch.div,
                           args=(val, t2))

        with self.assertRaises(RuntimeError):
            # Run backwards, and validate we receive an error.
            dist_autograd.backward([val.sum()])
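# Hedged sketch of the SimulateBackwardError helper assumed by the test above
# (its exact definition is not part of this listing): a custom autograd
# Function that is the identity in forward and raises in backward, so the
# error surfaces during the distributed backward pass.
import torch


class SimulateBackwardError(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input):
        return input  # identity; the interesting part happens in backward

    @staticmethod
    def backward(ctx, grad_output):
        raise RuntimeError("Simulate error on backward pass")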
def run_master(world_size, batch_size, microbatch_size):
    model = [
        RemoteModuleParams(nn.Linear, (784, 100), {}),
        RemoteModuleParams(nn.ReLU, (), {}),
        RemoteModuleParams(nn.Linear, (100, 100), {}),
        RemoteModuleParams(nn.ReLU, (), {}),
        RemoteModuleParams(nn.Linear, (100, 10), {}),
        RemoteModuleParams(nn.ReLU, (), {})
    ]
    torch.random.manual_seed(3)
    loss_fn = DistributedLoss(nn.CrossEntropyLoss)
    workers = [f"worker{i}/cpu" for i in range(0, world_size)]
    chunks = ceil(batch_size / microbatch_size)
    pipe = create_sequence_pipeline(model, [2] * world_size, workers,
                                    chunks=chunks)
    opt = DistributedOptimizer(
        optim.SGD,
        pipe.parameter_rrefs(),
        lr=0.05,
    )
    trainloader, testloader = get_data(batch_size)

    for i, (inputs, labels) in enumerate(trainloader):
        if i % 100 == 0:
            logging.info(f"Processing batch {i}")
            # evaluate(pipe, testloader)
        with dist_autograd.context() as context_id:
            outputs: RRef = pipe(inputs)
            loss = loss_fn(outputs, RRef(labels))  # block
            loss.backward(context_id)  # will run on last worker
            opt.step(context_id)
    evaluate(pipe, testloader)
def auto_graph_extract(devices):
    from fairscale.experimental.nn.distributed_pipeline.trace import make_graph

    device = devices[0].split("/")[1]
    torch.random.manual_seed(3)
    criterion = DistributedLoss(torch.nn.MSELoss)
    x = torch.randn(8, 4).to(device)

    # create model
    model = nn.Sequential(
        RemoteModule(devices[0], nn.Linear, (4, 4), {}),
        ShardedLinearLayer(devices[0], devices, devices[1]),
        RemoteModule(devices[0], nn.Linear, (4, 4), {}),
    )
    graph = make_graph(model)
    pipe = DistributedPipeline(graph, chunks=4)
    partitions = extract_partitions(graph, pipe)
    assert [[0, 1], [2], [3], [4], [5]] == partitions, f"partitions={partitions}"
    parameter_rrefs = pipe.parameter_rrefs()
    assert len(parameter_rrefs) == 8
    opt = DistributedOptimizer(
        torch.optim.SGD,
        parameter_rrefs,
        lr=0.05,
    )
    losses = []
    for i in range(2):
        with dist_autograd.context() as context_id:
            y = pipe(x)
            loss = criterion(y, rpc.RRef(x))
            losses.append(loss)
            loss.backward(context_id)
            opt.step(context_id)
    losses = [l.to_here() for l in losses]
    assert losses[0] > losses[1], f"{losses[0]} !> {losses[1]}"
def update(devices):
    torch.random.manual_seed(3)
    criterion = DistributedLoss(torch.nn.MSELoss)
    x = torch.randn(8, 4)
    model = [("linear1", nn.Linear, (4, 4), {}), ("relu", nn.ReLU, (), {})]
    pipe = MultiProcessPipe(model, balance=[1, 1], chunks=4,
                            devices=devices[:2])
    params = pipe.parameter_rrefs()
    opt = DistributedOptimizer(
        torch.optim.SGD,
        pipe.parameter_rrefs(),
        lr=0.05,
    )
    losses = []
    for i in range(2):
        with dist_autograd.context() as context_id:
            y = pipe(x)
            loss = criterion(y, rpc.RRef(x))
            losses.append(loss)
            loss.backward(context_id)
            opt.step(context_id)
    losses = [l.to_here() for l in losses]
    assert losses[0] > losses[1], f"{losses[0]} !> {losses[1]}"
def test_backward_node_failure(self):
    with dist_autograd.context() as context_id:
        t1 = torch.rand((3, 3), requires_grad=True)
        t2 = torch.rand((3, 3), requires_grad=True)
        res = rpc.rpc_sync('worker{}'.format(self._next_rank()), torch.add,
                           args=(t1, t2))

        # Wait for all RPCs to be done.
        dist.barrier()

        # Kill all odd rank nodes.
        if self.rank % 2 == 0:
            # Wait a bit for all other nodes to die.
            time.sleep(5)
            with self.assertRaisesRegex(
                    RuntimeError,
                    "Request aborted during client shutdown"):
                # Run backwards, and validate we receive an error since all
                # other nodes are dead.
                dist_autograd.backward([res.sum()])
        else:
            # Exit all other nodes.
            pass
def _train_step(self, formatter):
    self.model.train()
    total_loss = 0.
    total_correct = 0
    for batch_idx, (data, target) in enumerate(self.train_loader):
        with dist_autograd.context() as cid:
            output = self.model(data)
            target = target.long().squeeze(1)
            loss = self.loss_fn(output, target)
            total_loss += loss.item()
            correct = (torch.argmax(output, dim=1) == target).sum()
            total_correct += correct
            dist_autograd.backward([loss])
            # Ensure that dist autograd ran successfully and gradients were
            # returned.
            assert remote_method(
                MasterNetwork.get_dist_gradients,
                self.model.param_server_rref,
                cid) != {}
            self.optimizer.step()
        print(formatter.train_progress_message(
            batch_idx=batch_idx,
            batches=len(self.train_loader),
            training_examples=len(data),
            correct=correct,
            loss=loss.item()))
    train_loss = total_loss / len(self.train_loader.dataset)
    train_acc = total_correct / len(self.train_loader.dataset)
    return train_loss, train_acc
def train(rrefs, kwargs):
    model = TGCN(rrefs, kwargs['h_size'], kwargs['z_size'],
                 gru_hidden_units=kwargs['n_gru'])
    opt = DistributedOptimizer(Adam, model.parameter_rrefs(), lr=kwargs['lr'])

    times = []
    best = (None, 0)
    no_progress = 0
    for e in range(kwargs['epochs']):
        # Get loss and send backward
        model.train()
        with dist_autograd.context() as context_id:
            st = time.time()
            zs = model.forward(ld.LANL_Data.TRAIN)
            loss = model.loss_fn(zs, ld.LANL_Data.TRAIN,
                                 nratio=kwargs['nratio'])

            print("backward")
            dist_autograd.backward(context_id, [loss])

            print("step")
            opt.step(context_id)

            elapsed = time.time() - st
            times.append(elapsed)
            print('[%d] Loss %0.4f %0.2fs' % (e, loss.item(), elapsed))

        # Get validation info to prevent overfitting
        model.eval()
        with torch.no_grad():
            zs = model.forward(ld.LANL_Data.TRAIN, no_grad=True)
            v_loss = model.loss_fn(zs, ld.LANL_Data.VAL).item()
            print("\t Val loss: %0.4f" % v_loss)

            if v_loss > best[1]:
                best = (model.save_states(), v_loss)
            else:
                if e >= kwargs['min']:
                    no_progress += 1

            if no_progress == kwargs['patience']:
                print("Early stopping!")
                break

    model.load_states(best[0][0], best[0][1])
    zs, h0 = model(ld.LANL_Data.TEST, include_h=True)

    states = {'gcn': best[0][0], 'rnn': best[0][1]}
    f = open('model_save.pkl', 'wb+')
    pickle.dump(states, f, protocol=pickle.HIGHEST_PROTOCOL)

    print("Exiting train loop")
    print("Avg TPE: %0.4fs" % (sum(times) / len(times)))

    return model, zs[-1], h0
def update(devices):
    device = devices[0].split("/")[1]
    torch.random.manual_seed(3)
    criterion = DistributedLoss(torch.nn.MSELoss)
    x = torch.randn(8, 4).to(device)
    model = [
        RemoteModuleParams(nn.Linear, (4, 4), {}),
        RemoteModuleParams(nn.ReLU, (), {})
    ]
    pipe = create_sequence_pipeline(model, balance=[1, 1], chunks=4,
                                    devices=devices[:2])
    opt = DistributedOptimizer(
        torch.optim.SGD,
        pipe.parameter_rrefs(),
        lr=0.05,
    )
    losses = []
    for i in range(2):
        with dist_autograd.context() as context_id:
            y = pipe(x)
            loss = criterion(y, rpc.RRef(x))
            losses.append(loss)
            loss.backward(context_id)
            opt.step(context_id)
    losses = [l.to_here() for l in losses]
    assert losses[0] > losses[1], f"{losses[0]} !> {losses[1]}"
def test_dist_optim(self):
    # local version
    module1 = MyModule()
    module2 = MyModule()
    params = [module1.get_w(), module2.get_w()]
    local_optim = optim.SGD(params, lr=0.05)

    old_w1 = module1.w.clone().detach()
    old_w2 = module2.w.clone().detach()

    torch.manual_seed(0)
    t1 = torch.rand((3, 3), requires_grad=True)
    t2 = torch.rand((3, 3), requires_grad=True)
    output1 = module1.forward(t2)
    output2 = module2.forward(output1)
    loss = torch.add(output2, t1).sum()

    loss.backward()
    local_optim.step()

    # distributed version
    owner1 = "worker%d" % ((self.rank + 1) % self.world_size)
    owner2 = "worker%d" % ((self.rank + 2) % self.world_size)

    remote_module1 = rpc.remote(owner1, MyModule)
    remote_module2 = rpc.remote(owner2, MyModule)
    remote_param1 = remote_method(MyModule.get_w, remote_module1)
    remote_param2 = remote_method(MyModule.get_w, remote_module2)

    old_w1_remote = remote_param1.to_here()

    # sanity check: local and remote initial weights should match
    self.assertEqual(old_w1, remote_param1.to_here())
    self.assertEqual(old_w2, remote_param2.to_here())

    dist_optim = DistributedOptimizer(
        optim.SGD, [remote_param1, remote_param2], lr=0.05)

    with dist_autograd.context():
        torch.manual_seed(0)
        t1 = torch.rand((3, 3), requires_grad=True)
        t2 = torch.rand((3, 3), requires_grad=True)
        output1 = rpc_async_method(MyModule.forward, remote_module1, t2)
        output2 = rpc_async_method(MyModule.forward, remote_module2,
                                   output1.wait())
        loss = torch.add(output2.wait(), t1)

        dist_autograd.backward([loss.sum()])
        dist_optim.step()

        new_w1 = rpc_async_method(MyModule.get_w, remote_module1).wait()
        new_w2 = rpc_async_method(MyModule.get_w, remote_module2).wait()

        # ensure optimizer changed weights
        self.assertNotEqual(old_w1, new_w1)
        self.assertNotEqual(old_w2, new_w2)
        # ensure local equals remote
        self.assertEqual(new_w1, module1.get_w())
        self.assertEqual(new_w2, module2.get_w())
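# Hedged sketch of the MyModule, remote_method and rpc_async_method helpers
# assumed by the optimizer tests here (their exact definitions are not part of
# this listing): a module with a single weight, plus wrappers that invoke a
# method on the worker that owns an RRef.
import torch
import torch.distributed.rpc as rpc


class MyModule:
    def __init__(self, requires_grad=True):
        # a single 3x3 weight keeps local vs. remote comparisons simple
        self.w = torch.rand((3, 3), requires_grad=requires_grad)

    def forward(self, t1):
        return torch.mul(self.w, t1)

    def get_w(self):
        return self.w


def _call_method(method, obj_rref, *args, **kwargs):
    # runs on the owner of obj_rref, where local_value() is available
    return method(obj_rref.local_value(), *args, **kwargs)


def remote_method(method, obj_rref, *args, **kwargs):
    # returns an RRef to the result, created on obj_rref's owner
    return rpc.remote(obj_rref.owner(), _call_method,
                      args=(method, obj_rref) + args, kwargs=kwargs)


def rpc_async_method(method, obj_rref, *args, **kwargs):
    # returns a Future for the result computed on obj_rref's owner
    return rpc.rpc_async(obj_rref.owner(), _call_method,
                         args=(method, obj_rref) + args, kwargs=kwargs)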
def training_loop():
    model = RNN('ps', 3, 10, 1)
    X_tr, y_tr = gen_toy_data()
    X_te, y_te = gen_toy_data()

    opt = DistributedOptimizer(Adam, model.parameter_rrefs(), lr=0.01)
    loss_fn = nn.MSELoss()

    for e in range(100):
        with dist_autograd.context() as context_id:
            y_hat = model(X_tr)
            loss = loss_fn(y_hat, y_tr)

            dist_autograd.backward(context_id, [loss])
            opt.step(context_id)

            # No need to zero grad because it's blown
            # away every step by the dist API

        print("[%d] Loss: %0.4f" % (e, loss.item()))

    y_hat = model(X_te)
    y_hat[y_hat < 0.5] = 0
    y_hat[y_hat >= 0.5] = 1

    correct = float((y_hat == y_te).sum().item())
    total = float(y_hat.size(1))
    print("Final accuracy: %d/%d = %0.4f" % (correct, total, correct / total))
def _run_trainer(rref_t1, t2, ps, rank_diff):
    with dist_autograd.context() as context_id:
        ret = rpc.rpc_sync(ps, my_rref_add, args=(rref_t1, t2))
        dist_autograd.backward([ret.sum()])
        # prevent deleting dist autograd context
        rpc.rpc_sync(ps, _set_rpc_done, args=(context_id, rank_diff))
        rpc.rpc_sync(ps, _check_rpc_done, args=(0, ))
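# Hedged sketch of the _set_rpc_done / _check_rpc_done helpers assumed by
# several snippets in this listing (the names exist in the callers, but this
# storage scheme is an assumption): the callee records the caller's context id
# in process-global state, and another worker can block until that record
# exists.
import time

ctx_ids = {}    # rank_distance -> context_id, per process
rpc_done = {}   # rank_distance -> bool, per process


def _set_rpc_done(context_id, rank_distance):
    ctx_ids[rank_distance] = context_id
    rpc_done[rank_distance] = True


def _check_rpc_done(rank_distance):
    # spin until some caller at this rank distance has reported done
    while not rpc_done.get(rank_distance, False):
        time.sleep(0.1)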
def run_master(split_size):
    # put the two model parts on worker1 and worker2 respectively
    model = DistResNet50(split_size, ["worker1", "worker2"])
    loss_fn = nn.MSELoss()
    opt = DistributedOptimizer(
        optim.SGD,
        model.parameter_rrefs(),
        lr=0.05,
    )

    one_hot_indices = torch.LongTensor(batch_size) \
                           .random_(0, num_classes) \
                           .view(batch_size, 1)

    for i in range(num_batches):
        print(f"Processing batch {i}")
        # generate random inputs and labels
        inputs = torch.randn(batch_size, 3, image_w, image_h)
        labels = torch.zeros(batch_size, num_classes) \
                      .scatter_(1, one_hot_indices, 1)

        with dist_autograd.context() as context_id:
            outputs = model(inputs)
            dist_autograd.backward(context_id, [loss_fn(outputs, labels)])
            opt.step(context_id)
def _test_dist_optim_none_grads(self, optim_cls, *args, **kwargs):
    # local version
    module1 = MyModule()
    module2 = MyModule(requires_grad=False)
    params = [module1.get_w(), module2.get_w()]
    local_optim = optim_cls(params, *args, **kwargs)

    old_w1 = module1.w.clone().detach()
    old_w2 = module2.w.clone().detach()

    g_cpu = torch.Generator()
    g_cpu.manual_seed(0)
    t1 = torch.rand((3, 3), requires_grad=True, generator=g_cpu)
    t2 = torch.rand((3, 3), requires_grad=True, generator=g_cpu)
    output1 = module1.forward(t2)
    output2 = module2.forward(output1)
    loss = torch.add(output2, t1).sum()

    loss.backward()
    local_optim.step()

    # distributed version
    owner1 = "worker%d" % ((self.rank + 1) % self.world_size)
    owner2 = "worker%d" % ((self.rank + 2) % self.world_size)

    remote_module1 = rpc.remote(owner1, MyModule)
    remote_module2 = rpc.remote(owner2, MyModule, args=(False, ))
    remote_param1 = remote_module1.remote().get_w()
    remote_param2 = remote_module2.remote().get_w()

    # sanity check: local and remote initial weights should match
    self.assertEqual(old_w1, remote_param1.to_here())
    self.assertEqual(old_w2, remote_param2.to_here())

    dist_optim = DistributedOptimizer(
        optim_cls, [remote_param1, remote_param2], *args, **kwargs)

    with dist_autograd.context() as context_id:
        g_cpu.manual_seed(0)
        t1 = torch.rand((3, 3), requires_grad=True, generator=g_cpu)
        t2 = torch.rand((3, 3), requires_grad=True, generator=g_cpu)
        output1 = remote_module1.rpc_async().forward(t2)
        output2 = remote_module2.rpc_async().forward(output1.wait())
        loss = torch.add(output2.wait(), t1)

        dist_autograd.backward(context_id, [loss.sum()])
        dist_optim.step(context_id)

        new_w1 = remote_module1.rpc_async().get_w().wait()
        new_w2 = remote_module2.rpc_async().get_w().wait()

        # ensure the optimizer changed the weights for w1
        self.assertNotEqual(old_w1, new_w1)

        # ensure the optimizer did not change the weights for w2
        self.assertEqual(old_w2, new_w2)
        # ensure local equals remote
        self.assertEqual(new_w1, module1.get_w())
        self.assertEqual(new_w2, module2.get_w())
def test_dist_optim_exception(self):
    # distributed version
    owner1 = "worker%d" % ((self.rank + 1) % self.world_size)
    owner2 = "worker%d" % ((self.rank + 2) % self.world_size)

    remote_module1 = rpc.remote(owner1, MyModule)
    remote_module2 = rpc.remote(owner2, MyModule)
    remote_param1 = remote_method(MyModule.get_w, remote_module1)
    remote_param2 = remote_method(MyModule.get_w, remote_module2)

    dist_optim = DistributedOptimizer(
        FailingOptimizer, [remote_param1, remote_param2])

    with dist_autograd.context() as context_id:
        g_cpu = torch.Generator()
        g_cpu.manual_seed(0)
        t1 = torch.rand((3, 3), requires_grad=True, generator=g_cpu)
        t2 = torch.rand((3, 3), requires_grad=True, generator=g_cpu)
        output1 = rpc_async_method(MyModule.forward, remote_module1, t2)
        output2 = rpc_async_method(MyModule.forward, remote_module2,
                                   output1.wait())
        loss = torch.add(output2.wait(), t1).sum()

        dist_autograd.backward(context_id, [loss])
        with self.assertRaisesRegex(Exception, "Error running optimizer"):
            dist_optim.step(context_id)
def test_ddp_dist_autograd_sparse_grads(self):
    # Each trainer uses a different random seed. Otherwise, they are going
    # to have exactly the same initial model parameters, input, and
    # therefore grads. That means the grads will be the same before and
    # after DDP's all-reduce.
    torch.manual_seed(self.rank)
    dist.init_process_group(
        backend="gloo",
        init_method=INIT_METHOD_TEMPLATE.format(file_name=self.file_name),
        world_size=self.world_size,
        rank=self.rank,
    )

    model = nn.EmbeddingBag(10, 3, sparse=True)
    ddp_model = DistributedDataParallel(model)

    # Different inputs for each trainer.
    input = torch.LongTensor(10).random_(0, 10)
    offsets = torch.LongTensor([0, 4])

    # Run local.
    loss = ddp_model(input, offsets).sum()
    loss.backward()

    with dist_autograd.context() as context_id:
        loss = ddp_model(input, offsets).sum()
        dist_autograd.backward(context_id, [loss])
        grads_dict = dist_autograd.get_gradients(context_id)
        self.assertEqual(1, len(grads_dict))
        self.assertEqual(model.weight.grad, grads_dict[model.weight])
def run_training_loop(rank, num_gpus, train_loader, test_loader):
    # Runs the typical neural network forward + backward + optimizer step, but
    # in a distributed fashion.
    net = TrainerNet(num_gpus=num_gpus)
    # Build DistributedOptimizer.
    param_rrefs = net.get_global_param_rrefs()
    opt = DistributedOptimizer(optim.SGD, param_rrefs, lr=0.03)
    for i, (data, target) in enumerate(train_loader):
        with dist_autograd.context() as cid:
            model_output = net(data)
            target = target.to(model_output.device)
            loss = F.nll_loss(model_output, target)
            if i % 5 == 0:
                print(f"Rank {rank} training batch {i} loss {loss.item()}")
            dist_autograd.backward(cid, [loss])
            # Ensure that dist autograd ran successfully and gradients were
            # returned.
            assert remote_method(
                ParameterServer.get_dist_gradients,
                net.param_server_rref,
                cid) != {}
            opt.step(cid)

    print("Training complete!")
    print("Getting accuracy....")
    get_accuracy(test_loader, net)
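# Hedged sketch of the ParameterServer.get_dist_gradients method asserted on
# above (an assumption about the parameter-server class, not part of this
# listing): it looks up the gradients that the distributed backward pass
# accumulated for the given context id on the parameter-server side, moving
# them to CPU so they can be returned over RPC.
import torch.nn as nn
import torch.distributed.autograd as dist_autograd


class ParameterServer(nn.Module):
    # ... model definition, forward(), parameter RRefs, etc. ...

    def get_dist_gradients(self, context_id):
        grads = dist_autograd.get_gradients(context_id)
        # RPC ships CPU tensors, so copy any CUDA grads back to the host.
        return {k.to("cpu"): v.to("cpu") for k, v in grads.items()}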
def test_ddp_dist_autograd_local_vs_remote(self):
    # Each trainer uses a different random seed. Otherwise, they are going
    # to have exactly the same initial model parameters, input, and
    # therefore grads. That means the grads will be the same before and
    # after DDP's all-reduce.
    torch.manual_seed(self.rank)
    dist.init_process_group(
        backend="gloo",
        init_method="file://{}".format(self.file_name),
        world_size=self.world_size,
        rank=self.rank)

    remote_layer1 = RemoteModule("worker0", nn.Linear, args=(10, 5, False))
    layer1 = nn.Linear(10, 5, False)
    # Start with the same parameters for remote and local
    layer1.weight = remote_layer1.module_rref.to_here().weight

    # Run local case.
    layer2 = nn.Linear(5, 1)
    inputs = torch.rand((10, 10))
    ddp_model = DistributedDataParallel(layer2)
    loss = ddp_model(layer1(inputs)).sum()
    loss.backward()

    # Run remote case.
    with dist_autograd.context() as context_id:
        loss = ddp_model(remote_layer1(inputs)).sum()
        dist_autograd.backward(context_id, [loss])
        grads_dict = dist_autograd.get_gradients(context_id)
        dist.barrier()
        self.assertEqual(layer2.weight.grad, grads_dict[layer2.weight])
        self.assertEqual(
            layer1.weight.grad,
            rpc.rpc_sync("worker0",
                         DdpComparisonTest.get_remote_grads,
                         args=(remote_layer1.module_rref, context_id)))
def test_restore_context_after_swtich_to_jit_thread(self):
    if self.rank != 0:
        return

    @torch.jit.script
    def forward_script(
        context_id: int, dst_worker_name: str, t1: Tensor, t2: Tensor
    ) -> Tuple[Tensor, Tensor]:
        res1_fut = rpc.rpc_async(dst_worker_name, local_add, (t1, t1))
        res1 = res1_fut.wait()
        # After this, the script runs in a new JIT thread.
        loss1 = res1.sum()

        # SendRpcBackward is not attached, since the DistAutogradContext is
        # lost here.
        res2_fut = rpc.rpc_async(dst_worker_name, local_add, (t2, t2))
        res2 = res2_fut.wait()
        loss2 = res2.sum()

        return loss1, loss2

    with dist_autograd.context() as context_id:
        t1 = torch.ones((2, 3), requires_grad=True)
        t2 = torch.ones((2, 3), requires_grad=True)
        dst_worker_name = worker_name((self.rank + 1) % self.world_size)
        loss0, loss1 = forward_script(context_id, dst_worker_name, t1, t2)
        dist_autograd.backward(context_id, [loss0, loss1])
        grad0, grad1 = dist_autograd.get_gradients(context_id)
        self.assertEqual(grad0, grad1)
def _test_backward_rref(self, callee, rref_owner):
    local_grads = None
    t1 = torch.ones((3, 3), requires_grad=True)
    t2 = torch.zeros((3, 3), requires_grad=True)

    local_ret = torch.add(t1, t2)
    local_ret.sum().backward()
    with dist_autograd.context() as context_id:
        rref_t1 = rpc.remote(rref_owner, _torch_ones,
                             args=((3, 3), ),
                             kwargs={"requires_grad": True})

        if callee == rref_owner:
            rref = rpc.remote(callee, my_rref_add, args=(rref_t1, t2))
        else:
            rref = rpc.remote(callee, my_nested_rref_add,
                              args=(rref_owner, rref_t1, t2))
        ret = rref.to_here().wait()
        dist_autograd.backward([ret.sum()])

        # verify grads on caller
        grads = dist_autograd.get_gradients(context_id)
        self.assertIn(t2, grads)
        self.assertEqual(grads[t2], t2.grad)

        # verify grads on rref owner
        self.assertTrue(
            rpc.rpc_sync(rref_owner,
                         _compare_owner_value,
                         args=(context_id, rref_t1, t1.grad)))
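# Hedged sketch of the _torch_ones and my_rref_add helpers assumed by
# _test_backward_rref (assumptions about the surrounding test utilities):
# one builds a tensor on the RRef owner, the other adds an RRef's value to a
# plain tensor on whichever worker runs it.
import torch


def _torch_ones(sizes, requires_grad=False):
    return torch.ones(sizes, requires_grad=requires_grad)


def my_rref_add(rref_t1, t2):
    # runs on the RRef owner, so local_value() is available without a copy
    return torch.add(rref_t1.local_value(), t2)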
def test_context_cleanup_nested_rpc(self):
    # This is needed for the `dist.barrier` below.
    # For `RpcAgent`s other than `ProcessGroupAgent`,
    # no `_default_pg` is initialized.
    if not dist.is_initialized():
        dist.init_process_group(
            backend="gloo",
            init_method=self.init_method,
            rank=self.rank,
            world_size=self.world_size,
        )

    dst_rank = (self.rank + 1) % self.world_size
    nested_dst_rank = (dst_rank + 1) % self.world_size
    with dist_autograd.context() as context_id:
        t1 = torch.ones(3, 3, requires_grad=True)
        t2 = torch.zeros(3, 3, requires_grad=True)
        rpc.rpc_sync("worker{}".format(dst_rank),
                     my_py_nested_call,
                     args=(t1, t2, dst_rank, self.world_size, 0))
        # tell the next worker and the nested next worker to store this
        # context id so we can verify that it has been cleaned up
        rpc.rpc_sync("worker{}".format(dst_rank), _set_rpc_done,
                     args=(context_id, 1))
        rpc.rpc_sync("worker{}".format(nested_dst_rank), _set_rpc_done,
                     args=(context_id, 2))

    dist.barrier()  # let all nodes finish sending their RPCs
    success = _all_contexts_cleaned_up()
    self.assertTrue(success)
def _run_test_ddp_comparision(self, simulate_uneven_inputs=False):
    gLogger.info(f"Running trainer rank: {self.rank}")
    # Each trainer uses a different random seed. Otherwise, they are going
    # to have exactly the same initial model parameters, input, and
    # therefore grads. That means the grads will be the same before and
    # after DDP's all-reduce.
    torch.manual_seed(self.rank)
    dist.init_process_group(
        backend="gloo",
        init_method="file://{}".format(self.file_name),
        world_size=self.world_size,
        rank=self.rank,
    )
    net = nn.Linear(2, 3)
    ddp_net = DistributedDataParallel(net)

    # Odd ranks join early if simulate_uneven_inputs.
    num_inputs = 1
    if simulate_uneven_inputs:
        if self.rank % 2 == 0:
            num_inputs += 2
    inputs_list = [torch.rand((3, 2)) for _ in range(num_inputs)]

    if simulate_uneven_inputs:
        gLogger.info(
            f"Rank {self.rank} training with {len(inputs_list)} inputs.")

    # Use distributed autograd. The gradients will be in RPC context map.
    grads_dict = {}
    with ddp_net.join(simulate_uneven_inputs):
        for i, inputs in enumerate(inputs_list):
            with dist_autograd.context() as context_id:
                loss = ddp_net(inputs).norm()
                dist_autograd.backward(context_id, [loss])
                grads_dict = dist_autograd.get_gradients(context_id)
            gLogger.info(
                f"Trainer #{self.rank} got grad dict: {grads_dict}")

            # Use local autograd. The gradients will be in each variable's
            # '.grad'.
            ddp_net.zero_grad()
            loss = ddp_net(inputs).norm()
            loss.backward()

            # The gradients should be the same
            for param in net.parameters():
                self.assertTrue(
                    param in grads_dict,
                    msg=f"Param {param} is not in dist_auto grad dict "
                    f"{grads_dict} for iteration {i}",
                )
                self.assertEqual(
                    grads_dict[param],
                    param.grad,
                    msg=f"The grads for param {param} are different under local "
                    f"and dist autograd: {param.grad} \n---\n "
                    f"{grads_dict[param]} for iteration {i}",
                )
    dist.destroy_process_group()
def test_error_in_context(self):
    with dist_autograd.context() as context_id:
        t1 = torch.rand(3, 3, requires_grad=True)
        t2 = torch.rand(6, 6, requires_grad=True)

        with self.assertRaises(RuntimeError):
            # This should throw an error since matrix sizes don't match.
            rpc.rpc_sync('worker{}'.format(self._next_rank()), torch.matmul,
                         args=(t1, t2))
def test_graph_for_py_nested_call(self):
    dst_rank = (self.rank + 1) % self.world_size
    with dist_autograd.context() as context_id:
        t1 = torch.ones(3, 3, requires_grad=True)
        t2 = torch.zeros(3, 3, requires_grad=True)
        nest_dst_rank = (dst_rank + 1) % self.world_size
        ret = rpc.rpc_sync("worker{}".format(dst_rank),
                           my_py_nested_call,
                           args=(t1, t2, dst_rank, self.world_size, 1))
        for rd in [1, 2, 3]:
            rpc.rpc_sync("worker{}".format((self.rank + rd) % self.world_size),
                         _set_rpc_done, args=(context_id, rd))

        # For self.rank, there are 4 graphs to verify.
        # One is for the current context id, where this rank sends the first
        # rpc call.
        # The second is for the prev context id, where this rank makes the
        # 1st nested call.
        # The third is for the prev prev context id, where this rank makes
        # the 2nd nested call.
        # The last is for the prev prev prev context id, where this rank
        # executes the torch.add() operator.

        # Verify first graph for current context id.
        ctx = dist_autograd._current_context()
        self.assertEqual(context_id, ctx._context_id())
        send_functions = ctx._send_functions()
        self.assertEqual(1, len(send_functions))
        recv_functions = ctx._recv_functions()
        self.assertEqual(1, len(recv_functions))
        self._verify_graph_for_first_rpc_call(
            list(send_functions.values())[0],
            list(recv_functions.values())[0],
            t1, t2, ret)

        # Verify second graph for 1st nested call.
        self._check_rpc_done(1)
        ctx = dist_autograd._retrieve_context(ctx_ids[1])
        self._verify_graph_for_nested_rpc_call(ctx)

        # Verify third graph for 2nd nested call.
        self._check_rpc_done(2)
        ctx = dist_autograd._retrieve_context(ctx_ids[2])
        self._verify_graph_for_nested_rpc_call(ctx)

        # Verify last graph for rpc call execution.
        self._check_rpc_done(3)
        ctx = dist_autograd._retrieve_context(ctx_ids[3])
        send_functions = ctx._send_functions()
        self.assertEqual(1, len(send_functions))
        self._verify_graph_for_rpc_call_exec(
            list(send_functions.values())[0])

        # this barrier is needed so one worker does not clean up their
        # autograd context before another worker tries to access it.
        dist.barrier()