def auto_graph_extract(devices):
    from fairscale.experimental.nn.distributed_pipeline.trace import make_graph

    device = devices[0].split("/")[1]
    torch.random.manual_seed(3)
    criterion = DistributedLoss(torch.nn.MSELoss)
    x = torch.randn(8, 4).to(device)

    # create model
    model = nn.Sequential(
        RemoteModule(devices[0], nn.Linear, (4, 4), {}),
        ShardedLinearLayer(devices[0], devices, devices[1]),
        RemoteModule(devices[0], nn.Linear, (4, 4), {}),
    )
    graph = make_graph(model)
    pipe = DistributedPipeline(graph, chunks=4)
    partitions = extract_partitions(graph, pipe)
    assert [[0, 1], [2], [3], [4], [5]] == partitions, f"partitions={partitions}"
    parameter_rrefs = pipe.parameter_rrefs()
    assert len(parameter_rrefs) == 8
    opt = DistributedOptimizer(
        torch.optim.SGD,
        parameter_rrefs,
        lr=0.05,
    )
    losses = []
    for i in range(2):
        with dist_autograd.context() as context_id:
            y = pipe(x)
            loss = criterion(y, rpc.RRef(x))
            losses.append(loss)
            loss.backward(context_id)
            opt.step(context_id)
    losses = [l.to_here() for l in losses]
    assert losses[0] > losses[1], f"{losses[0]} !> {losses[1]}"
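The extract_partitions helper used above is a small test utility that is not shown. A plausible sketch, assuming the pipeline exposes its partitions and each partition records the graph nodes it covers (these attribute names are assumptions, not a documented API):

# Hypothetical sketch of the extract_partitions test helper.
def extract_partitions(graph, pipe):
    # Map every node of every partition back to its index in the traced graph.
    return [sorted(graph.nodes.index(node) for node in part.nodes) for part in pipe.partitions]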
Example #2
    def __init__(self, input_device, shard_devices, output_device):
        super().__init__()
        self.split = RemoteModule(input_device, SplitTensors, (), {})
        self.linear_layers_2 = nn.ModuleList([
            RemoteModule(shard_devices[0], nn.Linear, (2, 2), {}),
            RemoteModule(shard_devices[1], nn.Linear, (2, 2), {}),
        ])
        self.concatenate = RemoteModule(output_device, ConcatenateTensors, ())
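Only the constructor is shown above. A forward pass for this sharded layer, together with the two small helper modules it uses, might look like the following sketch; the definitions of SplitTensors and ConcatenateTensors below are assumptions, not taken from the original test file:

    def forward(self, input):
        # Split the activations, push each shard through its own remote linear
        # layer, then concatenate the results on the output device.
        shards = self.split(input)
        shards = [layer(shard) for layer, shard in zip(self.linear_layers_2, shards)]
        return self.concatenate(*shards)

# Assumed helper modules (the originals are defined elsewhere in the test file):
class SplitTensors(nn.Module):
    def forward(self, input):
        # Split the feature dimension into two halves.
        return torch.split(input, (input.size(1) + 1) // 2, dim=1)

class ConcatenateTensors(nn.Module):
    def forward(self, *inputs):
        # Concatenate the shards back along the feature dimension.
        return torch.cat(inputs, dim=1)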
Example #3
    def test_ddp_dist_autograd_local_vs_remote_gpu(self):
        # Each trainer uses a different random seed. Otherwise, they are going
        # to have exactly the same initial model parameters, input, and
        # therefore grads. That means the grads will be the same before and
        # after DDP's all-reduce.
        torch.manual_seed(self.rank)
        dist.init_process_group(backend="gloo",
                                init_method="file://{}".format(self.file_name),
                                world_size=self.world_size,
                                rank=self.rank)

        remote_layer1 = RemoteModule("worker0", nn.Linear, args=(10, 7, False))
        layer1 = nn.Linear(10, 7, False)
        # Start with the same parameters for remote and local
        layer1.weight = remote_layer1.module_rref.to_here().weight

        layer2 = nn.Linear(7, 5).cuda(self.rank)
        ddp_layer2 = DistributedDataParallel(layer2, device_ids=[self.rank])

        remote_layer3 = RemoteModule("worker0", nn.Linear, args=(5, 3, False))
        layer3 = nn.Linear(5, 3, False)
        # Start with the same parameters for remote and local
        layer3.weight = remote_layer3.module_rref.to_here().weight

        layer4 = nn.Linear(3, 1).cuda(self.rank)
        ddp_layer4 = DistributedDataParallel(layer4, device_ids=[self.rank])

        # Run local case.
        inputs = torch.rand((10, 10))
        loss = ddp_layer4(
            layer3(ddp_layer2(layer1(inputs).cuda(self.rank)).cpu()).cuda(
                self.rank)).sum()
        loss.backward()

        # Run remote case.
        with dist_autograd.context() as context_id:
            loss = ddp_layer4(
                remote_layer3(
                    ddp_layer2(remote_layer1(inputs).cuda(
                        self.rank)).cpu()).cuda(self.rank)).sum()
            dist_autograd.backward(context_id, [loss])
            grads_dict = dist_autograd.get_gradients(context_id)
            dist.barrier()
            self.assertEqual(
                layer1.weight.grad,
                rpc.rpc_sync("worker0",
                             DdpComparisonTest.get_remote_grads,
                             args=(remote_layer1.module_rref, context_id)))
            self.assertEqual(layer2.weight.grad, grads_dict[layer2.weight])
            self.assertEqual(
                layer3.weight.grad,
                rpc.rpc_sync("worker0",
                             DdpComparisonTest.get_remote_grads,
                             args=(remote_layer3.module_rref, context_id)))
            self.assertEqual(layer4.weight.grad, grads_dict[layer4.weight])
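The get_remote_grads static method fetched via rpc_sync above (and again in the test_ddp_dist_autograd_local_vs_remote example further down) is not shown. A plausible sketch, assuming it simply looks up the remote module's weight gradient inside the given distributed autograd context:

    # Hypothetical sketch of the get_remote_grads helper used via rpc_sync above.
    @staticmethod
    def get_remote_grads(rref, context_id):
        # Resolve the RRef on the owning worker and return the gradient that
        # dist_autograd accumulated for the module's weight in this context.
        return dist_autograd.get_gradients(context_id)[rref.local_value().weight]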
Example #4
    def test_bad_module(self):
        if self.rank != 0:
            return
        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
        args = (1,)
        kwargs = dict(first_kwarg=2)

        with self.assertRaisesRegex(
            ValueError,
            r"Expect `module_cls\(\*args, \*\*kwargs\)` returns an instance of <class nn.Module>,",
        ):
            RemoteModule(dst_worker_name, BadModule, args, kwargs)

        with self.assertRaisesRegex(
            ValueError,
            r"Expect `module_cls\(\*args, \*\*kwargs\)` returns an instance of <class nn.Module>,",
        ):
            RemoteModule(dst_worker_name, BadModule, args, kwargs)
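BadModule itself is not part of the snippet. Judging from the error message, it is any class whose instances are not nn.Module subclasses; a minimal stand-in could look like this (an assumption, not the original definition):

# Hypothetical stand-in for BadModule: a plain class rather than an nn.Module,
# which is what makes RemoteModule raise the ValueError asserted above.
class BadModule:
    def __init__(self, first_arg, first_kwarg=-1):
        pass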
Example #5
def multi_input_multi_output_layers(devices):
    device = devices[0].split("/")[1]
    torch.random.manual_seed(3)
    criterion = DistributedLoss(torch.nn.MSELoss)
    x = torch.randn(8, 4).to(device)

    #                                / ->linear_layer_2_1
    # input -> linear_layer1 -> split                     ->concatenate
    #                                \ ->linear_layer_2_2

    linear_layer_1 = RemoteModule(devices[0], nn.Linear, (4, 4), {})
    split = RemoteModule(devices[0], SplitTensors, (), {})
    linear_layers_2 = [
        RemoteModule(devices[0], nn.Linear, (2, 2), {}),
        RemoteModule(devices[1], nn.Linear, (2, 2), {}),
    ]
    concatenate = RemoteModule(devices[1], ConcatenateTensors, ())

    graph = PipelineModulesGraph()
    graph.add_sequence([linear_layer_1, split], [0], 2)
    for i, l in enumerate(linear_layers_2):
        graph.add_layer(l, [(split, i)])
    graph.add_layer(concatenate, linear_layers_2)

    pipe = DistributedPipeline(graph, chunks=4)
    assert [[0, 1], [2], [3], [4]] == extract_partitions(graph, pipe)
    parameter_rrefs = pipe.parameter_rrefs()
    assert len(parameter_rrefs) == 6
    opt = DistributedOptimizer(
        torch.optim.SGD,
        parameter_rrefs,
        lr=0.05,
    )
    losses = []
    for i in range(2):
        with dist_autograd.context() as context_id:
            y = pipe(x)
            loss = criterion(y, rpc.RRef(x))
            losses.append(loss)
            loss.backward(context_id)
            opt.step(context_id)
    losses = [l.to_here() for l in losses]
    assert losses[0] > losses[1], f"{losses[0]} !> {losses[1]}"
Example #6
    def test_ddp_dist_autograd_local_vs_remote(self):
        # Each trainer uses a different random seed. Otherwise, they are going
        # to have exactly the same initial model parameters, input, and
        # therefore grads. That means the grads will be the same before and
        # after DDP's all-reduce.
        torch.manual_seed(self.rank)
        dist.init_process_group(
            backend="gloo",
            init_method=INIT_METHOD_TEMPLATE.format(file_name=self.file_name),
            world_size=self.world_size,
            rank=self.rank,
        )

        # Use two different remote-device input strings, with and without the
        # default device string "cpu", respectively.
        for remote_device in ["worker0/cpu", "worker0"]:
            remote_layer1 = RemoteModule(
                remote_device=remote_device, module_cls=nn.Linear, args=(10, 5, False)
            )
            layer1 = nn.Linear(10, 5, False)
            # Start with the same parameters for remote and local
            layer1.weight = remote_layer1.module_rref.to_here().weight

            # Run local case.
            layer2 = nn.Linear(5, 1)
            inputs = torch.rand((10, 10))
            ddp_model = DistributedDataParallel(layer2)
            loss = ddp_model(layer1(inputs)).sum()
            loss.backward()

            # Run remote case.
            with dist_autograd.context() as context_id:
                loss = ddp_model(remote_layer1(inputs)).sum()
                dist_autograd.backward(context_id, [loss])
                grads_dict = dist_autograd.get_gradients(context_id)
                dist.barrier()
                self.assertEqual(layer2.weight.grad, grads_dict[layer2.weight])
                self.assertEqual(
                    layer1.weight.grad,
                    rpc.rpc_sync(
                        "worker0",
                        CommonDdpComparisonTest.get_remote_grads,
                        args=(remote_layer1.module_rref, context_id),
                    ),
                )
Example #7
    def _create_remote_module_iter(remote_device, modes=None):
        if modes is None:
            modes = ModuleCreationMode.__members__.values()

        args = (1,)
        kwargs = dict(first_kwarg=2)

        if ModuleCreationMode.MODULE_CTOR in modes:
            remote_module = RemoteModule(remote_device, MyModule, args, kwargs)
            yield remote_module

        if ModuleCreationMode.MODULE_CTOR_WITH_INTERFACE in modes:
            remote_module = _RemoteModule(
                remote_device,
                create_scripted_module,
                args,
                kwargs,
                _module_interface_cls=MyModuleInterface,
            )
            scripted_remote_module = torch.jit.script(remote_module)
            yield scripted_remote_module
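A hedged usage sketch for this helper: iterate over the remote modules it yields and run a forward call on each. The forward() arguments below are an assumption about MyModule's signature:

# Hypothetical usage of _create_remote_module_iter.
for remote_module in _create_remote_module_iter("worker1/cpu"):
    # forward_async returns a Future; wait() blocks for the remote result.
    ret = remote_module.forward_async(torch.randn(2, 2)).wait()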
Example #8
def create_sequence_pipeline(
    layers: List[RemoteModuleParams], balance: List[int], devices: List[str], **kwargs: Any
) -> DistributedPipeline:
    """A simple helper function to create a pipeline from list of pipeline-modules that run sequentially.
    Args:
        layers: list of modules. They should not be already assigned a remote-device.
        balance: a list of integers how layers should be paritioned. Sum of numbers in 'balance'
            should be equal to the number of layers.
        devices: specification of remote device for each partition. Should be of the same length
            as 'balance'.
    """
    remote_modules: List[RemoteModule] = []
    index = 0
    for num_layers, remote_device in zip(balance, devices):
        next_index = index + num_layers
        for li in range(index, next_index):
            remote_modules.append(RemoteModule(remote_device, **layers[li]._asdict()))
        index = next_index

    graph = PipelineModulesGraph()
    graph.add_sequence(remote_modules, [0])

    return DistributedPipeline(graph, **kwargs)
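A hedged usage example, assuming RemoteModuleParams is a NamedTuple with module_cls, args, and kwargs fields matching RemoteModule's constructor (see the _asdict() call above):

# Hypothetical usage of create_sequence_pipeline: three layers split into two
# partitions, one per worker. The 'chunks' argument is forwarded to
# DistributedPipeline via **kwargs.
layers = [
    RemoteModuleParams(nn.Linear, (4, 4), {}),
    RemoteModuleParams(nn.ReLU, (), {}),
    RemoteModuleParams(nn.Linear, (4, 4), {}),
]
pipe = create_sequence_pipeline(layers, balance=[2, 1], devices=["worker0/cpu", "worker1/cpu"], chunks=4)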
Example #9
def run_worker(rank, world_size):
    r"""
    A wrapper function that initializes RPC, calls the function, and shuts down
    RPC.
    """

    # We need to use different port numbers in TCP init_method for init_rpc and
    # init_process_group to avoid port conflicts.
    rpc_backend_options = TensorPipeRpcBackendOptions()
    rpc_backend_options.init_method = "tcp://localhost:29501"

    # Rank 2 is the master, rank 3 is the parameter server, and ranks 0 and 1
    # are trainers.
    if rank == 2:
        rpc.init_rpc(
            "master",
            rank=rank,
            world_size=world_size,
            rpc_backend_options=rpc_backend_options,
        )

        remote_emb_module = RemoteModule(
            "ps",
            torch.nn.EmbeddingBag,
            args=(NUM_EMBEDDINGS, EMBEDDING_DIM),
            kwargs={"mode": "sum"},
        )

        # Run the training loop on trainers.
        futs = []
        for trainer_rank in [0, 1]:
            trainer_name = "trainer{}".format(trainer_rank)
            fut = rpc.rpc_async(trainer_name,
                                _run_trainer,
                                args=(remote_emb_module, trainer_rank))
            futs.append(fut)

        # Wait for all training to finish.
        for fut in futs:
            fut.wait()
    elif rank <= 1:
        # Initialize the process group for DistributedDataParallel on the trainers.
        dist.init_process_group(backend="gloo",
                                rank=rank,
                                world_size=2,
                                init_method="tcp://localhost:29500")

        # Initialize RPC.
        trainer_name = "trainer{}".format(rank)
        rpc.init_rpc(
            trainer_name,
            rank=rank,
            world_size=world_size,
            rpc_backend_options=rpc_backend_options,
        )

        # Trainer just waits for RPCs from master.
    else:
        rpc.init_rpc(
            "ps",
            rank=rank,
            world_size=world_size,
            rpc_backend_options=rpc_backend_options,
        )
        # The parameter server does nothing here; it only hosts the module
        # created remotely by the master.
        pass

    # Block until all RPCs finish.
    rpc.shutdown()
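run_worker is launched once per rank; a minimal launcher sketch, assuming four processes (trainers 0 and 1, master 2, parameter server 3) and torch.multiprocessing:

# Hypothetical launcher for run_worker.
if __name__ == "__main__":
    import torch.multiprocessing as mp

    world_size = 4
    # mp.spawn passes the rank as the first argument to run_worker.
    mp.spawn(run_worker, args=(world_size,), nprocs=world_size, join=True)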
Example #10
def create_remote_module_by_module_rref(remote_device, module_rref):
    return RemoteModule(remote_device=remote_device, module_rref=module_rref)
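A hedged usage sketch: instantiate the module remotely first, then wrap the resulting RRef so the module is not constructed a second time. The worker name, device string, and forward arguments are assumptions:

# Hypothetical usage of create_remote_module_by_module_rref.
linear_rref = rpc.remote("worker1", torch.nn.Linear, args=(4, 4))
remote_linear = create_remote_module_by_module_rref("worker1/cpu", linear_rref)
out = remote_linear.forward(torch.randn(2, 4))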