Example No. 1
    def test_nested_context(self):
        with dist_autograd.context() as context_id:
            # Nested contexts not supported.
            with self.assertRaisesRegex(
                    RuntimeError,
                    "Already have an autograd context id for this thread"):
                with dist_autograd.context() as context_id:
                    pass
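Note: every snippet on this page assumes the RPC framework has already been initialized on each process before a distributed autograd context is opened. A minimal, self-contained sketch of that setup and of a backward pass over an RPC boundary (the worker names, port, and world size are illustrative choices, not taken from the examples):

import os

import torch
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc
import torch.multiprocessing as mp


def run(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"
    rpc.init_rpc(f"worker{rank}", rank=rank, world_size=world_size)

    if rank == 0:
        t1 = torch.ones(3, 3, requires_grad=True)
        t2 = torch.ones(3, 3, requires_grad=True)
        with dist_autograd.context() as context_id:
            # torch.add executes on worker1; the autograd graph spans both workers.
            loss = rpc.rpc_sync("worker1", torch.add, args=(t1, t2)).sum()
            dist_autograd.backward(context_id, [loss])
            grads = dist_autograd.get_gradients(context_id)
            print(grads[t1])  # a 3x3 tensor of ones

    rpc.shutdown()  # blocks until all workers are done


if __name__ == "__main__":
    mp.spawn(run, args=(2,), nprocs=2, join=True)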
Example No. 2
def run_master():

    # put the model parts on worker1 through worker4
    model = DistResNet(["worker1", "worker2", "worker3", "worker4"])
    loss_fn = nn.MSELoss()
    opt = DistributedOptimizer(
        optim.SGD,
        model.parameter_rrefs(),
        lr=0.05,
    )

    one_hot_indices = torch.LongTensor(batch_size) \
                           .random_(0, num_classes) \
                           .view(batch_size, 1)

    for i in range(num_batches):
        print(f"Processing batch {i}")
        # generate random inputs and labels
        inputs = torch.randn(batch_size, 3, image_w, image_h)
        labels = torch.zeros(batch_size, num_classes) \
                      .scatter_(1, one_hot_indices, 1)

        # The distributed autograd context is the dedicated scope for the
        # distributed backward pass to store gradients, which can later be
        # retrieved using the context_id by the distributed optimizer.
        with dist_autograd.context() as context_id:
            outputs = model(inputs)
            dist_autograd.backward(context_id, [loss_fn(outputs, labels)])
            opt.step(context_id)
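Note: run_master() relies on module-level globals (batch_size, num_classes, image_w, image_h, num_batches) and a DistResNet class that are not shown here. A hedged sketch of the launcher such code typically sits behind, following the pattern of the PyTorch pipeline-parallelism tutorial (the world size and port are assumptions):

import os

import torch.distributed.rpc as rpc
import torch.multiprocessing as mp


def run_worker(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"

    if rank == 0:
        # The master drives the training loop; the other ranks only host
        # the DistResNet shards and answer RPCs.
        rpc.init_rpc("master", rank=rank, world_size=world_size)
        run_master()
    else:
        rpc.init_rpc(f"worker{rank}", rank=rank, world_size=world_size)

    # Block until all RPCs have finished.
    rpc.shutdown()


if __name__ == "__main__":
    world_size = 5  # one master plus worker1..worker4 for the model shards
    mp.spawn(run_worker, args=(world_size,), nprocs=world_size, join=True)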
Example No. 3
    def test_worker_ids_recorded(self):
        dst_ranks = {rank for rank in range(self.world_size) if rank != self.rank}
        with dist_autograd.context() as context_id:
            # if no tensors require grad, we do not add the send functions, so
            # no worker ids should be recorded.
            t1 = torch.ones(3, 3, requires_grad=False)
            t2 = torch.zeros(3, 3, requires_grad=False)
            for dst_rank in dst_ranks:
                ret = rpc.rpc_sync("worker{}".format(dst_rank), torch.add, args=(t1, t2))
                rpc.rpc_sync(
                    "worker{}".format(dst_rank), _set_rpc_done, args=(context_id, 1)
                )
            # no worker ids should be recorded.
            ctx = dist_autograd._current_context()
            worker_ids = ctx._known_worker_ids()
            self.assertEqual(len(worker_ids), 0)

            # worker_ids should be recorded when tensors do require grad
            t1.requires_grad = True
            t2.requires_grad = True
            for dst_rank in dst_ranks:
                ret = rpc.rpc_sync("worker{}".format(dst_rank), torch.add, args=(t1, t2))
                rpc.rpc_sync(
                    "worker{}".format(dst_rank), _set_rpc_done, args=(context_id, 1)
                )
            # all worker_ids in dst_ranks should be recorded.
            worker_ids = ctx._known_worker_ids()
            self.assertEqual(len(worker_ids), len(dst_ranks))
            self.assertEqual(set(worker_ids), dst_ranks)
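Note: _set_rpc_done, _check_rpc_done and the ctx_ids list referenced by the graph tests on this page are helpers from the test module and are not shown here. A sketch of how they are plausibly defined, assuming module-level globals indexed by the caller's rank distance:

import time

# Flags and context ids recorded per "rank distance" of the calling worker.
rpc_done = [False, False, False, False]
ctx_ids = [-1, -1, -1, -1]


def _set_rpc_done(ctx_id, rank_distance):
    # Invoked over RPC so the callee remembers the caller's context id.
    rpc_done[rank_distance] = True
    ctx_ids[rank_distance] = ctx_id


def _check_rpc_done(rank_distance):
    # Busy-wait until the corresponding _set_rpc_done RPC has arrived.
    while not rpc_done[rank_distance]:
        time.sleep(0.1)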
Example No. 4
    def test_rpc_complex_args(self):
        with dist_autograd.context() as context_id:
            num_tensors = 10
            tensors = []
            for i in range(num_tensors):
                tensors.append(torch.ones(3, 3, requires_grad=(i % 2 == 0)))
            ret = rpc.rpc_sync(
                "worker{}".format(self._next_rank()), torch.stack, args=(tensors,)
            )
            self.assertEqual(torch.stack(tensors), ret)

            # Verify appropriate tensors have been attached to the autograd graph.
            next_funcs = list(
                dist_autograd._current_context()._send_functions().values()
            )[0].next_functions
            idx = 0
            for i in range(num_tensors):
                if i % 2 == 0:
                    self.assertEqual(
                        "torch::autograd::AccumulateGrad", next_funcs[i][0].name()
                    )
                    self.assertEqual(tensors[i], next_funcs[i][0].variable)
                else:
                    self.assertIsNone(next_funcs[i][0])

            # Verify that the worker id has been recorded in the context
            ctx = dist_autograd._current_context()
            worker_ids = ctx._known_worker_ids()
            self.assertEqual(len(worker_ids), 1)
            dst_rank = (self.rank + 1) % self.world_size
            self.assertEqual(worker_ids[0], dst_rank)
Example No. 5
    def test_graph_for_py_nested_call_itself(self):
        dst_rank = (self.rank + 1) % self.world_size
        with dist_autograd.context() as context_id:
            t1 = torch.ones(3, 3, requires_grad=True)
            t2 = torch.zeros(3, 3, requires_grad=True)
            ret = rpc.rpc_sync("worker{}".format(dst_rank),
                               my_py_nested_call,
                               args=(t1, t2, (self.rank - 1 + self.world_size) % self.world_size, self.world_size, 0))
            rpc.rpc_sync("worker{}".format((self.rank + 1) % self.world_size),
                         _set_rpc_done, args=(context_id, 1))

            # For self.rank, there are 2 graphs to verify.
            # One is for the current context id, created when this rank sends
            # the first rpc call and executes the torch.add() operator.
            # The other is for the prev context id, created when this rank
            # makes the nested call.
            ctx = dist_autograd._current_context()
            self.assertEqual(context_id, ctx._context_id())
            send_functions = ctx._send_functions()
            self.assertEqual(2, len(send_functions))
            recv_functions = ctx._recv_functions()
            self.assertEqual(2, len(recv_functions))
            self._verify_graph_for_first_rpc_call(list(send_functions.values())[0],
                                                  list(recv_functions.values())[1],
                                                  t1, t2, ret)
            self._verify_graph_for_rpc_call_exec(list(send_functions.values())[1])

            # Verify two pairs of send and recv functions for nested
            # call
            self._check_rpc_done(1)
            ctx = dist_autograd._retrieve_context(ctx_ids[1])
            self._verify_graph_for_nested_rpc_call(ctx)
            # this barrier is needed so one worker does not clean up its
            # autograd context before another worker tries to access it.
            dist.barrier()
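Note: my_py_nested_call is another test-module helper not shown on this page. Judging from the graph checks above, it presumably forwards the tensors one worker further down the ring until the hop count reaches zero and then runs torch.add, roughly:

import torch
import torch.distributed.rpc as rpc


def my_py_nested_call(t1, t2, dst, world_size, hops):
    next_dst = (dst + 1) % world_size
    if hops > 0:
        # Keep forwarding the tensors to the next worker.
        return rpc.rpc_sync("worker{}".format(next_dst), my_py_nested_call,
                            args=(t1, t2, next_dst, world_size, hops - 1))
    # Final hop: execute the actual operator.
    return rpc.rpc_sync("worker{}".format(next_dst), torch.add, args=(t1, t2))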
Example No. 6
def train(model, optimizer, epoch, data, train_loader):
    model.train()

    pbar = tqdm(total=int(data.train_mask.sum()))
    pbar.set_description(f'Epoch {epoch:02d}')
    x = data.x
    y = data.y.squeeze()
    total_loss = total_correct = 0
    for batch_size, n_id, adjs in train_loader:
        # `adjs` holds a list of `(edge_index, e_id, size)` tuples.
        with dist_autograd.context() as context_id:
            adjs = [adj for adj in adjs]

            out = model(x[n_id], adjs)
            loss = F.nll_loss(out, y[n_id[:batch_size]])
            dist_autograd.backward(context_id, [loss])
            optimizer.step(context_id)

            total_loss += float(loss)
            total_correct += int(
                out.argmax(dim=-1).eq(y[n_id[:batch_size]]).sum())
            pbar.update(batch_size)

    pbar.close()

    loss = total_loss / len(train_loader)
    approx_acc = total_correct / int(data.train_mask.sum())

    return loss, approx_acc
Example No. 7
    def test_backward_invalid_args(self):
        with dist_autograd.context() as context_id:

            with self.assertRaisesRegex(TypeError,
                                        "incompatible function arguments"):
                dist_autograd.backward(None)

            with self.assertRaisesRegex(
                    RuntimeError,
                    "No tensors provided for gradient computation"):
                dist_autograd.backward([])

            with self.assertRaisesRegex(RuntimeError,
                                        "requires_grad not set on"):
                t = torch.rand(3, 3)
                dist_autograd.backward([t])

            with self.assertRaisesRegex(
                    RuntimeError,
                    "is not a scalar, all roots need to be scalar"):
                t = torch.rand(3, 3, requires_grad=True)
                dist_autograd.backward([t])

            with self.assertRaisesRegex(
                    RuntimeError, "does not have a valid gradient function"):
                t = torch.rand(1, requires_grad=True)
                dist_autograd.backward([t])
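Note: for contrast with the failure cases above, the sketch below shows a valid invocation (single-process setup for brevity). Current PyTorch releases take context_id as the first argument to dist_autograd.backward; the test above exercises an older signature that omitted it.

import os

import torch
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc

os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29500"
rpc.init_rpc("worker0", rank=0, world_size=1)

with dist_autograd.context() as context_id:
    t = torch.rand(3, 3, requires_grad=True)
    loss = t.sum()  # a scalar root that has a grad_fn
    dist_autograd.backward(context_id, [loss])
    grads = dist_autograd.get_gradients(context_id)
    assert torch.equal(grads[t], torch.ones(3, 3))

rpc.shutdown()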
Example No. 8
    def test_backward_autograd_engine_error(self):
        with dist_autograd.context() as context_id:
            t1 = torch.rand((3, 3), requires_grad=True)
            t2 = torch.rand((3, 3), requires_grad=True)

            # Perform some ops before error simulation.
            tmp = (t1 + t2) * (t1 + t2)
            t3 = SimulateBackwardError.apply(tmp)

            # Run multiple round trips across different nodes and verify the
            # original node receives an error thrown on a node deep in the chain.
            val = rpc.rpc_sync('worker{}'.format(self._next_rank()),
                               torch.add,
                               args=(t2, t3))
            val = rpc.rpc_sync('worker{}'.format(self._next_rank()),
                               torch.mul,
                               args=(val, t2))
            val = rpc.rpc_sync('worker{}'.format(self._next_rank()),
                               torch.matmul,
                               args=(val, t2))
            val = rpc.rpc_sync('worker{}'.format(self._next_rank()),
                               torch.div,
                               args=(val, t2))

            with self.assertRaises(RuntimeError):
                # Run backwards, and validate we receive an error.
                dist_autograd.backward([val.sum()])
Example No. 9
def run_master(world_size, batch_size, microbatch_size):
    model = [
        RemoteModuleParams(nn.Linear, (784, 100), {}),
        RemoteModuleParams(nn.ReLU, (), {}),
        RemoteModuleParams(nn.Linear, (100, 100), {}),
        RemoteModuleParams(nn.ReLU, (), {}),
        RemoteModuleParams(nn.Linear, (100, 10), {}),
        RemoteModuleParams(nn.ReLU, (), {})
    ]
    torch.random.manual_seed(3)
    loss_fn = DistributedLoss(nn.CrossEntropyLoss)
    workers = [f"worker{i}/cpu" for i in range(0, world_size)]
    chunks = ceil(batch_size / microbatch_size)
    pipe = create_sequence_pipeline(model, [2] * world_size,
                                    workers,
                                    chunks=chunks)
    opt = DistributedOptimizer(
        optim.SGD,
        pipe.parameter_rrefs(),
        lr=0.05,
    )

    trainloader, testloader = get_data(batch_size)
    for i, (inputs, labels) in enumerate(trainloader):
        if i % 100 == 0:
            logging.info(f"Processing batch {i}")
            # evaluate(pipe, testloader)

        with dist_autograd.context() as context_id:
            outputs: RRef = pipe(inputs)
            loss = loss_fn(outputs, RRef(labels))  # block
            loss.backward(context_id)  # will run on last worker
            opt.step(context_id)
    evaluate(pipe, testloader)
Example No. 10
def auto_graph_extract(devices):
    from fairscale.experimental.nn.distributed_pipeline.trace import make_graph

    device = devices[0].split("/")[1]
    torch.random.manual_seed(3)
    criterion = DistributedLoss(torch.nn.MSELoss)
    x = torch.randn(8, 4).to(device)

    # create model
    model = nn.Sequential(
        RemoteModule(devices[0], nn.Linear, (4, 4), {}),
        ShardedLinearLayer(devices[0], devices, devices[1]),
        RemoteModule(devices[0], nn.Linear, (4, 4), {}),
    )
    graph = make_graph(model)
    pipe = DistributedPipeline(graph, chunks=4)
    partitions = extract_partitions(graph, pipe)
    assert [[0, 1], [2], [3], [4], [5]] == partitions, f"partitions={partitions}"
    parameter_rrefs = pipe.parameter_rrefs()
    assert len(parameter_rrefs) == 8
    opt = DistributedOptimizer(
        torch.optim.SGD,
        parameter_rrefs,
        lr=0.05,
    )
    losses = []
    for i in range(2):
        with dist_autograd.context() as context_id:
            y = pipe(x)
            loss = criterion(y, rpc.RRef(x))
            losses.append(loss)
            loss.backward(context_id)
            opt.step(context_id)
    losses = [l.to_here() for l in losses]
    assert losses[0] > losses[1], f"{losses[0]} !> {losses[1]}"
Example No. 11
def update(devices):
    torch.random.manual_seed(3)
    criterion = DistributedLoss(torch.nn.MSELoss)
    x = torch.randn(8, 4)
    model = [("linear1", nn.Linear, (4, 4), {}), ("relu", nn.ReLU, (), {})]
    pipe = MultiProcessPipe(model,
                            balance=[1, 1],
                            chunks=4,
                            devices=devices[:2])
    params = pipe.parameter_rrefs()
    opt = DistributedOptimizer(
        torch.optim.SGD,
        pipe.parameter_rrefs(),
        lr=0.05,
    )
    losses = []
    for i in range(2):
        with dist_autograd.context() as context_id:
            y = pipe(x)
            loss = criterion(y, rpc.RRef(x))
            losses.append(loss)
            loss.backward(context_id)
            opt.step(context_id)
    losses = [l.to_here() for l in losses]
    assert losses[0] > losses[1], f"{losses[0]} !> {losses[1]}"
Example No. 12
    def test_backward_node_failure(self):
        with dist_autograd.context() as context_id:
            t1 = torch.rand((3, 3), requires_grad=True)
            t2 = torch.rand((3, 3), requires_grad=True)

            res = rpc.rpc_sync('worker{}'.format(self._next_rank()),
                               torch.add,
                               args=(t1, t2))

            # Wait for all RPCs to be done.
            dist.barrier()

            # Kill all odd rank nodes.
            if self.rank % 2 == 0:
                # Wait a bit for all other nodes to die.
                time.sleep(5)
                with self.assertRaisesRegex(
                        RuntimeError,
                        "Request aborted during client shutdown"):
                    # Run backwards, and validate we receive an error since all
                    # other nodes are dead.
                    dist_autograd.backward([res.sum()])
            else:
                # Exit all other nodes.
                pass
Example No. 13
 def _train_step(self, formatter):
     self.model.train()
     total_loss = 0.
     total_correct = 0
     for batch_idx, (data, target) in enumerate(self.train_loader):
         with dist_autograd.context() as cid:
             output = self.model(data)
             target = target.long().squeeze(1)
             loss = self.loss_fn(output, target)
             total_loss += loss.item()
             correct = (torch.argmax(output, dim=1) == target).sum()
             total_correct += correct
             dist_autograd.backward([loss])
             # Ensure that dist autograd ran successfully and gradients were
             # returned.
             assert remote_method(
                 MasterNetwork.get_dist_gradients,
                 self.model.param_server_rref,
                 cid) != {}
             self.optimizer.step()
             print(formatter.train_progress_message(batch_idx=batch_idx, batches=len(self.train_loader),
                                                    training_examples=len(data), correct=correct,
                                                    loss=loss.item()))
     train_loss = total_loss / len(self.train_loader.dataset)
     train_acc = total_correct / len(self.train_loader.dataset)
     return train_loss, train_acc
Example No. 14
def train(rrefs, kwargs):
    model = TGCN(rrefs,
                 kwargs['h_size'],
                 kwargs['z_size'],
                 gru_hidden_units=kwargs['n_gru'])

    opt = DistributedOptimizer(Adam, model.parameter_rrefs(), lr=kwargs['lr'])

    times = []
    best = (None, 0)
    no_progress = 0
    for e in range(kwargs['epochs']):
        # Get loss and send backward
        model.train()
        with dist_autograd.context() as context_id:
            st = time.time()
            zs = model.forward(ld.LANL_Data.TRAIN)
            loss = model.loss_fn(zs,
                                 ld.LANL_Data.TRAIN,
                                 nratio=kwargs['nratio'])

            print("backward")
            dist_autograd.backward(context_id, [loss])

            print("step")
            opt.step(context_id)

            elapsed = time.time() - st
            times.append(elapsed)
            print('[%d] Loss %0.4f  %0.2fs' % (e, loss.item(), elapsed))

        # Get validation info to prevent overfitting
        model.eval()
        with torch.no_grad():
            zs = model.forward(ld.LANL_Data.TRAIN, no_grad=True)
            v_loss = model.loss_fn(zs, ld.LANL_Data.VAL).item()

            print("\t Val loss: %0.4f" % v_loss)

            if v_loss > best[1]:
                best = (model.save_states(), v_loss)
            else:
                if e >= kwargs['min']:
                    no_progress += 1

            if no_progress == kwargs['patience']:
                print("Early stopping!")
                break

    model.load_states(best[0][0], best[0][1])
    zs, h0 = model(ld.LANL_Data.TEST, include_h=True)

    states = {'gcn': best[0][0], 'rnn': best[0][1]}
    with open('model_save.pkl', 'wb+') as f:
        pickle.dump(states, f, protocol=pickle.HIGHEST_PROTOCOL)

    print("Exiting train loop")
    print("Avg TPE: %0.4fs" % (sum(times) / len(times)))

    return model, zs[-1], h0
Example No. 15
def update(devices):
    device = devices[0].split("/")[1]
    torch.random.manual_seed(3)
    criterion = DistributedLoss(torch.nn.MSELoss)
    x = torch.randn(8, 4).to(device)
    model = [
        RemoteModuleParams(nn.Linear, (4, 4), {}),
        RemoteModuleParams(nn.ReLU, (), {})
    ]
    pipe = create_sequence_pipeline(model,
                                    balance=[1, 1],
                                    chunks=4,
                                    devices=devices[:2])
    opt = DistributedOptimizer(
        torch.optim.SGD,
        pipe.parameter_rrefs(),
        lr=0.05,
    )
    losses = []
    for i in range(2):
        with dist_autograd.context() as context_id:
            y = pipe(x)
            loss = criterion(y, rpc.RRef(x))
            losses.append(loss)
            loss.backward(context_id)
            opt.step(context_id)
    losses = [l.to_here() for l in losses]
    assert losses[0] > losses[1], f"{losses[0]} !> {losses[1]}"
Example No. 16
    def test_dist_optim(self):
        # local version
        module1 = MyModule()
        module2 = MyModule()
        params = [module1.get_w(), module2.get_w()]
        local_optim = optim.SGD(params, lr=0.05)

        old_w1 = module1.w.clone().detach()
        old_w2 = module2.w.clone().detach()

        torch.manual_seed(0)
        t1 = torch.rand((3, 3), requires_grad=True)
        t2 = torch.rand((3, 3), requires_grad=True)
        output1 = module1.forward(t2)
        output2 = module2.forward(output1)
        loss = torch.add(output2, t1).sum()

        loss.backward()
        local_optim.step()

        # distributed version
        owner1 = "worker%d" % ((self.rank + 1) % self.world_size)
        owner2 = "worker%d" % ((self.rank + 2) % self.world_size)

        remote_module1 = rpc.remote(owner1, MyModule)
        remote_module2 = rpc.remote(owner2, MyModule)
        remote_param1 = remote_method(MyModule.get_w, remote_module1)
        remote_param2 = remote_method(MyModule.get_w, remote_module2)

        old_w1_remote = remote_param1.to_here()

        # sanity check: local and remote initial weights should match
        self.assertEqual(old_w1, remote_param1.to_here())
        self.assertEqual(old_w2, remote_param2.to_here())

        dist_optim = DistributedOptimizer(optim.SGD,
                                          [remote_param1, remote_param2],
                                          lr=0.05)

        with dist_autograd.context():
            torch.manual_seed(0)
            t1 = torch.rand((3, 3), requires_grad=True)
            t2 = torch.rand((3, 3), requires_grad=True)
            output1 = rpc_async_method(MyModule.forward, remote_module1, t2)
            output2 = rpc_async_method(MyModule.forward, remote_module2,
                                       output1.wait())
            loss = torch.add(output2.wait(), t1)

            dist_autograd.backward([loss.sum()])
            dist_optim.step()

            new_w1 = rpc_async_method(MyModule.get_w, remote_module1).wait()
            new_w2 = rpc_async_method(MyModule.get_w, remote_module2).wait()

            # ensure optimizer changed weights
            self.assertNotEqual(old_w1, new_w1)
            self.assertNotEqual(old_w2, new_w2)
            # ensure local equals remote
            self.assertEqual(new_w1, module1.get_w())
            self.assertEqual(new_w2, module2.get_w())
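Note: remote_method and rpc_async_method are wrapper helpers defined elsewhere in this test code. A common definition, along the lines of the PyTorch RPC tutorials (a sketch, since the actual helpers are not shown here):

import torch.distributed.rpc as rpc


def _call_method(method, rref, *args, **kwargs):
    # Runs on the RRef's owner: fetch the local object and call the method on it.
    return method(rref.local_value(), *args, **kwargs)


def remote_method(method, rref, *args, **kwargs):
    # Synchronous RPC to the owner of `rref`.
    return rpc.rpc_sync(rref.owner(), _call_method,
                        args=[method, rref] + list(args), kwargs=kwargs)


def rpc_async_method(method, rref, *args, **kwargs):
    # Same as above, but returns a Future that callers .wait() on.
    return rpc.rpc_async(rref.owner(), _call_method,
                         args=[method, rref] + list(args), kwargs=kwargs)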
Example No. 17
def training_loop():
    model = RNN('ps', 3, 10, 1)
    X_tr, y_tr = gen_toy_data()
    X_te, y_te = gen_toy_data()

    opt = DistributedOptimizer(Adam, model.parameter_rrefs(), lr=0.01)

    loss_fn = nn.MSELoss()

    for e in range(100):
        with dist_autograd.context() as context_id:
            y_hat = model(X_tr)
            loss = loss_fn(y_hat, y_tr)

            dist_autograd.backward(context_id, [loss])
            opt.step(context_id)
            # No need to zero gradients; the distributed autograd context
            # (and the gradients it holds) is discarded after every step.

        print("[%d] Loss: %0.4f" % (e, loss.item()))

    y_hat = model(X_te)
    y_hat[y_hat < 0.5] = 0
    y_hat[y_hat >= 0.5] = 1

    correct = float((y_hat == y_te).sum().item())
    total = float(y_hat.size(1))
    print("Final accuracy: %d/%d = %0.4f" % (correct, total, correct / total))
Example No. 18
def _run_trainer(rref_t1, t2, ps, rank_diff):
    with dist_autograd.context() as context_id:
        ret = rpc.rpc_sync(ps, my_rref_add, args=(rref_t1, t2))
        dist_autograd.backward([ret.sum()])
        # prevent deleting dist autograd context
        rpc.rpc_sync(ps, _set_rpc_done, args=(context_id, rank_diff))
        rpc.rpc_sync(ps, _check_rpc_done, args=(0, ))
Example No. 19
def run_master(split_size):

    # put the two model parts on worker1 and worker2 respectively
    model = DistResNet50(split_size, ["worker1", "worker2"])
    loss_fn = nn.MSELoss()
    opt = DistributedOptimizer(
        optim.SGD,
        model.parameter_rrefs(),
        lr=0.05,
    )

    one_hot_indices = torch.LongTensor(batch_size) \
                           .random_(0, num_classes) \
                           .view(batch_size, 1)

    for i in range(num_batches):
        print(f"Processing batch {i}")
        # generate random inputs and labels
        inputs = torch.randn(batch_size, 3, image_w, image_h)
        labels = torch.zeros(batch_size, num_classes) \
                      .scatter_(1, one_hot_indices, 1)

        with dist_autograd.context() as context_id:
            outputs = model(inputs)
            dist_autograd.backward(context_id, [loss_fn(outputs, labels)])
            opt.step(context_id)
Example No. 20
    def _test_dist_optim_none_grads(self, optim_cls, *args, **kwargs):
        # local version
        module1 = MyModule()
        module2 = MyModule(requires_grad=False)
        params = [module1.get_w(), module2.get_w()]
        local_optim = optim_cls(params, *args, **kwargs)

        old_w1 = module1.w.clone().detach()
        old_w2 = module2.w.clone().detach()

        g_cpu = torch.Generator()
        g_cpu.manual_seed(0)
        t1 = torch.rand((3, 3), requires_grad=True, generator=g_cpu)
        t2 = torch.rand((3, 3), requires_grad=True, generator=g_cpu)
        output1 = module1.forward(t2)
        output2 = module2.forward(output1)
        loss = torch.add(output2, t1).sum()

        loss.backward()
        local_optim.step()

        # distributed version
        owner1 = "worker%d" % ((self.rank + 1) % self.world_size)
        owner2 = "worker%d" % ((self.rank + 2) % self.world_size)

        remote_module1 = rpc.remote(owner1, MyModule)
        remote_module2 = rpc.remote(owner2, MyModule, args=(False, ))
        remote_param1 = remote_module1.remote().get_w()
        remote_param2 = remote_module2.remote().get_w()

        # sanity check: local and remote initial weights should match
        self.assertEqual(old_w1, remote_param1.to_here())
        self.assertEqual(old_w2, remote_param2.to_here())

        dist_optim = DistributedOptimizer(optim_cls,
                                          [remote_param1, remote_param2],
                                          *args, **kwargs)

        with dist_autograd.context() as context_id:
            g_cpu.manual_seed(0)
            t1 = torch.rand((3, 3), requires_grad=True, generator=g_cpu)
            t2 = torch.rand((3, 3), requires_grad=True, generator=g_cpu)
            output1 = remote_module1.rpc_async().forward(t2)
            output2 = remote_module2.rpc_async().forward(output1.wait())
            loss = torch.add(output2.wait(), t1)

            dist_autograd.backward(context_id, [loss.sum()])
            dist_optim.step(context_id)

            new_w1 = remote_module1.rpc_async().get_w().wait()
            new_w2 = remote_module2.rpc_async().get_w().wait()

            # ensure optimizer changed weights for w1
            self.assertNotEqual(old_w1, new_w1)

            # ensure the optimizer did not change weights for w2
            self.assertEqual(old_w2, new_w2)
            # ensure local equals remote
            self.assertEqual(new_w1, module1.get_w())
            self.assertEqual(new_w2, module2.get_w())
Example No. 21
    def test_dist_optim_exception(self):
        # distributed version
        owner1 = "worker%d" % ((self.rank + 1) % self.world_size)
        owner2 = "worker%d" % ((self.rank + 2) % self.world_size)

        remote_module1 = rpc.remote(owner1, MyModule)
        remote_module2 = rpc.remote(owner2, MyModule)
        remote_param1 = remote_method(MyModule.get_w, remote_module1)
        remote_param2 = remote_method(MyModule.get_w, remote_module2)

        dist_optim = DistributedOptimizer(FailingOptimizer,
                                          [remote_param1, remote_param2])

        with dist_autograd.context() as context_id:
            g_cpu = torch.Generator()
            g_cpu.manual_seed(0)
            t1 = torch.rand((3, 3), requires_grad=True, generator=g_cpu)
            t2 = torch.rand((3, 3), requires_grad=True, generator=g_cpu)
            output1 = rpc_async_method(MyModule.forward, remote_module1, t2)
            output2 = rpc_async_method(MyModule.forward, remote_module2,
                                       output1.wait())
            loss = torch.add(output2.wait(), t1).sum()

            dist_autograd.backward(context_id, [loss])
            with self.assertRaisesRegex(Exception, "Error running optimizer"):
                dist_optim.step(context_id)
Example No. 22
    def test_ddp_dist_autograd_sparse_grads(self):
        # Each trainer uses a different random seed. Otherwise, they are going
        # to have exactly the same initial model parameters, input, and
        # therefore grads. That means the grads will be the same before and
        # after DDP's all-reduce.
        torch.manual_seed(self.rank)
        dist.init_process_group(
            backend="gloo",
            init_method=INIT_METHOD_TEMPLATE.format(file_name=self.file_name),
            world_size=self.world_size,
            rank=self.rank,
        )

        model = nn.EmbeddingBag(10, 3, sparse=True)
        ddp_model = DistributedDataParallel(model)

        # Different inputs for each trainer
        input = torch.LongTensor(10).random_(0, 10)
        offsets = torch.LongTensor([0, 4])

        # Run local.
        loss = ddp_model(input, offsets).sum()
        loss.backward()

        with dist_autograd.context() as context_id:
            loss = ddp_model(input, offsets).sum()
            dist_autograd.backward(context_id, [loss])
            grads_dict = dist_autograd.get_gradients(context_id)
            self.assertEqual(1, len(grads_dict))
            self.assertEqual(model.weight.grad, grads_dict[model.weight])
Example No. 23
def run_training_loop(rank, num_gpus, train_loader, test_loader):
    # Runs the typical neural network forward + backward + optimizer step, but
    # in a distributed fashion.
    net = TrainerNet(num_gpus=num_gpus)
    # Build DistributedOptimizer.
    param_rrefs = net.get_global_param_rrefs()
    opt = DistributedOptimizer(optim.SGD, param_rrefs, lr=0.03)
    for i, (data, target) in enumerate(train_loader):
        with dist_autograd.context() as cid:
            model_output = net(data)
            target = target.to(model_output.device)
            loss = F.nll_loss(model_output, target)
            if i % 5 == 0:
                print(f"Rank {rank} training batch {i} loss {loss.item()}")
            dist_autograd.backward(cid, [loss])
            # Ensure that dist autograd ran successfully and gradients were
            # returned.
            assert remote_method(
                ParameterServer.get_dist_gradients,
                net.param_server_rref,
                cid) != {}
            opt.step(cid)

    print("Training complete!")
    print("Getting accuracy....")
    get_accuracy(test_loader, net)
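Note: the assertion above inspects the gradient map held on the parameter server. In the parameter-server tutorial this pattern follows, get_dist_gradients is essentially a thin wrapper around dist_autograd.get_gradients; a sketch (the full ParameterServer, which also hosts the model and serves forward(), is omitted):

import torch.distributed.autograd as dist_autograd
import torch.nn as nn


class ParameterServer(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Linear(4, 2)  # placeholder for the real model

    def get_dist_gradients(self, cid):
        # Gradients computed by dist_autograd.backward() live in the context
        # map keyed by context id, not in param.grad. Move them to CPU so
        # they can be serialized back to the trainer over RPC.
        grads = dist_autograd.get_gradients(cid)
        return {k.to("cpu"): v.to("cpu") for k, v in grads.items()}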
Example No. 24
    def test_ddp_dist_autograd_local_vs_remote(self):
        # Each trainer uses a different random seed. Otherwise, they are going
        # to have exactly the same initial model parameters, input, and
        # therefore grads. That means the grads will be the same before and
        # after DDP's all-reduce.
        torch.manual_seed(self.rank)
        dist.init_process_group(backend="gloo",
                                init_method="file://{}".format(self.file_name),
                                world_size=self.world_size,
                                rank=self.rank)

        remote_layer1 = RemoteModule("worker0", nn.Linear, args=(10, 5, False))
        layer1 = nn.Linear(10, 5, False)
        # Start with the same parameters for remote and local
        layer1.weight = remote_layer1.module_rref.to_here().weight

        # Run local case.
        layer2 = nn.Linear(5, 1)
        inputs = torch.rand((10, 10))
        ddp_model = DistributedDataParallel(layer2)
        loss = ddp_model(layer1(inputs)).sum()
        loss.backward()

        # Run remote case.
        with dist_autograd.context() as context_id:
            loss = ddp_model(remote_layer1(inputs)).sum()
            dist_autograd.backward(context_id, [loss])
            grads_dict = dist_autograd.get_gradients(context_id)
            dist.barrier()
            self.assertEqual(layer2.weight.grad, grads_dict[layer2.weight])
            self.assertEqual(
                layer1.weight.grad,
                rpc.rpc_sync("worker0",
                             DdpComparisonTest.get_remote_grads,
                             args=(remote_layer1.module_rref, context_id)))
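Note: DdpComparisonTest.get_remote_grads runs on worker0 and reads the remote layer's gradient out of that worker's view of the same distributed autograd context; presumably something along these lines (a sketch, not the actual test helper):

import torch.distributed.autograd as dist_autograd


class DdpComparisonTest:
    @staticmethod
    def get_remote_grads(rref, context_id):
        # Executed on the owner of `rref`: fetch the local module and look up
        # its weight gradient in the distributed autograd context.
        return dist_autograd.get_gradients(context_id)[rref.local_value().weight]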
Example No. 25
    def test_restore_context_after_swtich_to_jit_thread(self):
        if self.rank != 0:
            return

        @torch.jit.script
        def forward_script(
            context_id: int, dst_worker_name: str, t1: Tensor, t2: Tensor
        ) -> Tuple[Tensor, Tensor]:
            res1_fut = rpc.rpc_async(dst_worker_name, local_add, (t1, t1))
            res1 = res1_fut.wait()  # After this, the script runs in a new JIT thread.
            loss1 = res1.sum()

            # SendRpcBackward is not attached, since the DistAutogradContext is lost here.
            res2_fut = rpc.rpc_async(dst_worker_name, local_add, (t2, t2))
            res2 = res2_fut.wait()
            loss2 = res2.sum()

            return loss1, loss2

        with dist_autograd.context() as context_id:
            t1 = torch.ones((2, 3), requires_grad=True)
            t2 = torch.ones((2, 3), requires_grad=True)
            dst_worker_name = worker_name((self.rank + 1) % self.world_size)
            loss0, loss1 = forward_script(context_id, dst_worker_name, t1, t2)
            dist_autograd.backward(context_id, [loss0, loss1])
            grad0, grad1 = dist_autograd.get_gradients(context_id)
            self.assertEqual(grad0, grad1)
Example No. 26
    def _test_backward_rref(self, callee, rref_owner):
        local_grads = None
        t1 = torch.ones((3, 3), requires_grad=True)
        t2 = torch.zeros((3, 3), requires_grad=True)

        local_ret = torch.add(t1, t2)
        local_ret.sum().backward()
        with dist_autograd.context() as context_id:
            rref_t1 = rpc.remote(rref_owner,
                                 _torch_ones,
                                 args=((3, 3), ),
                                 kwargs={"requires_grad": True})

            if callee == rref_owner:
                rref = rpc.remote(callee, my_rref_add, args=(rref_t1, t2))
            else:
                rref = rpc.remote(callee,
                                  my_nested_rref_add,
                                  args=(rref_owner, rref_t1, t2))
            ret = rref.to_here().wait()
            dist_autograd.backward([ret.sum()])

            # verify grads on caller
            grads = dist_autograd.get_gradients(context_id)
            self.assertIn(t2, grads)
            self.assertEqual(grads[t2], t2.grad)

            # verify grads on rref owner
            self.assertTrue(
                rpc.rpc_sync(rref_owner,
                             _compare_owner_value,
                             args=(context_id, rref_t1, t1.grad)))
Example No. 27
    def test_context_cleanup_nested_rpc(self):
        # This is needed for the `dist.barrier` below.
        # For an `RpcAgent` other than `ProcessGroupAgent`,
        # no `_default_pg` is initialized.
        if not dist.is_initialized():
            dist.init_process_group(
                backend="gloo",
                init_method=self.init_method,
                rank=self.rank,
                world_size=self.world_size,
            )

        dst_rank = (self.rank + 1) % self.world_size
        nested_dst_rank = (dst_rank + 1) % self.world_size
        with dist_autograd.context() as context_id:
            t1 = torch.ones(3, 3, requires_grad=True)
            t2 = torch.zeros(3, 3, requires_grad=True)
            rpc.rpc_sync("worker{}".format(dst_rank),
                         my_py_nested_call,
                         args=(t1, t2, dst_rank, self.world_size, 0))
            # tell next worker and nested next worker to store this context id
            # so we can verify that it has been cleaned up
            rpc.rpc_sync("worker{}".format(dst_rank),
                         _set_rpc_done,
                         args=(context_id, 1))
            rpc.rpc_sync("worker{}".format(nested_dst_rank),
                         _set_rpc_done,
                         args=(context_id, 2))
        dist.barrier()  # let all nodes finish sending their RPCs
        success = _all_contexts_cleaned_up()
        self.assertTrue(success)
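Note: _all_contexts_cleaned_up is not shown on this page. It plausibly polls until every context id registered via _set_rpc_done can no longer be retrieved, which is how the test concludes that context cleanup propagated to this worker. A sketch, assuming _set_rpc_done also records ids in a module-level known_context_ids set:

import time

import torch.distributed.autograd as dist_autograd

known_context_ids = set()  # filled in by _set_rpc_done on this worker


def _all_contexts_cleaned_up(timeout_seconds=10):
    start = time.time()
    cleaned = set()
    while time.time() - start < timeout_seconds and cleaned != known_context_ids:
        for context_id in known_context_ids:
            try:
                dist_autograd._retrieve_context(context_id)
            except RuntimeError:
                # Retrieval fails once the context has been cleaned up.
                cleaned.add(context_id)
        time.sleep(0.1)
    return cleaned == known_context_ids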
Example No. 28
    def _run_test_ddp_comparision(self, simulate_uneven_inputs=False):
        gLogger.info(f"Running trainer rank: {self.rank}")
        # Each trainer uses a different random seed. Otherwise, they are going
        # to have exactly the same initial model parameters, input, and
        # therefore grads. That means the grads will be the same before and
        # after DDP's all-reduce.
        torch.manual_seed(self.rank)
        dist.init_process_group(
            backend="gloo",
            init_method="file://{}".format(self.file_name),
            world_size=self.world_size,
            rank=self.rank,
        )
        net = nn.Linear(2, 3)
        ddp_net = DistributedDataParallel(net)

        # Odd ranks join early if simulate_uneven_inputs.
        num_inputs = 1
        if simulate_uneven_inputs:
            if self.rank % 2 == 0:
                num_inputs += 2
        inputs_list = [torch.rand((3, 2)) for _ in range(num_inputs)]

        if simulate_uneven_inputs:
            gLogger.info(
                f"Rank {self.rank} training with {len(inputs_list)} inputs.")

        # Use distributed autograd. The gradients will be in RPC context map.
        grads_dict = {}
        with ddp_net.join(simulate_uneven_inputs):
            for i, inputs in enumerate(inputs_list):
                with dist_autograd.context() as context_id:
                    loss = ddp_net(inputs).norm()
                    dist_autograd.backward(context_id, [loss])
                    grads_dict = dist_autograd.get_gradients(context_id)
                gLogger.info(
                    f"Trainer #{self.rank} got grad dict: {grads_dict}")

                # Use local autograd. The gradients will be in each variable's '.grad'.
                ddp_net.zero_grad()
                loss = ddp_net(inputs).norm()
                loss.backward()

                # The gradients should be the same
                for param in net.parameters():
                    self.assertTrue(
                        param in grads_dict,
                        msg=
                        f"Param {param} is not in dist_auto grad dict {grads_dict} for iteration {i}",
                    )
                    self.assertEqual(
                        grads_dict[param],
                        param.grad,
                        msg=
                        f"The grads for param {param} are different under local "
                        f"and dist autograd: {param.grad} \n---\n {grads_dict[param]} for iteration {i}",
                    )
        dist.destroy_process_group()
Example No. 29
    def test_error_in_context(self):
        with dist_autograd.context() as context_id:
            t1 = torch.rand(3, 3, requires_grad=True)
            t2 = torch.rand(6, 6, requires_grad=True)

            with self.assertRaises(RuntimeError):
                # This should throw an error since matrix sizes don't match.
                rpc.rpc_sync('worker{}'.format(self._next_rank()),
                             torch.matmul,
                             args=(t1, t2))
Example No. 30
    def test_graph_for_py_nested_call(self):
        dst_rank = (self.rank + 1) % self.world_size
        with dist_autograd.context() as context_id:
            t1 = torch.ones(3, 3, requires_grad=True)
            t2 = torch.zeros(3, 3, requires_grad=True)
            nest_dst_rank = (dst_rank + 1) % self.world_size
            ret = rpc.rpc_sync("worker{}".format(dst_rank),
                               my_py_nested_call,
                               args=(t1, t2, dst_rank, self.world_size, 1))
            for rd in [1, 2, 3]:
                rpc.rpc_sync("worker{}".format(
                    (self.rank + rd) % self.world_size),
                             _set_rpc_done,
                             args=(context_id, rd))

            # For self.rank, there are 4 graphs to verify.
            # One is for the current context id, created when this rank sends
            # the first rpc call.
            # The second is for the prev context id, created when this rank
            # makes the 1st nested call.
            # The third is for the prev prev context id, created when this
            # rank makes the 2nd nested call.
            # The last is for the prev prev prev context id, created when
            # this rank executes the torch.add() operator.

            # Verify first graph for current context id.
            ctx = dist_autograd._current_context()
            self.assertEqual(context_id, ctx._context_id())
            send_functions = ctx._send_functions()
            self.assertEqual(1, len(send_functions))
            recv_functions = ctx._recv_functions()
            self.assertEqual(1, len(recv_functions))
            self._verify_graph_for_first_rpc_call(
                list(send_functions.values())[0],
                list(recv_functions.values())[0], t1, t2, ret)

            # Verify second graph for 1st nested call.
            self._check_rpc_done(1)
            ctx = dist_autograd._retrieve_context(ctx_ids[1])
            self._verify_graph_for_nested_rpc_call(ctx)

            # Verify third graph for 2nd nested call.
            self._check_rpc_done(2)
            ctx = dist_autograd._retrieve_context(ctx_ids[2])
            self._verify_graph_for_nested_rpc_call(ctx)

            # Verify last graph for rpc call execution.
            self._check_rpc_done(3)
            ctx = dist_autograd._retrieve_context(ctx_ids[3])
            send_functions = ctx._send_functions()
            self.assertEqual(1, len(send_functions))
            self._verify_graph_for_rpc_call_exec(
                list(send_functions.values())[0])
            # this barrier is needed so one worker does not clean up its
            # autograd context before another worker tries to access it.
            dist.barrier()