Example #1
    def forward(ctx, input, P_send, P_recv, preserve_batch,
                input_tensor_structure, output_tensor_structure, dtype):

        ctx.P_send = P_send
        ctx.P_recv = P_recv
        ctx.preserve_batch = preserve_batch
        ctx.input_tensor_structure = input_tensor_structure
        ctx.output_tensor_structure = output_tensor_structure
        ctx.dtype = dtype

        input_tensor_shape = input_tensor_structure[2]
        output_requires_grad = output_tensor_structure[0]
        output_tensor_shape = output_tensor_structure[2]

        # This allows all ranks to use the same exit path, so that we can be
        # sure that all requests have cleared.
        if preserve_batch:
            output = zero_volume_tensor(input.shape[0])
        else:
            output = zero_volume_tensor()

        requests = []

        # By design, the roots are always 0 in the cross-communicators
        # If I receive data (either from a remote worker or just from myself)
        # I need to reduce that data.  If I send and receive to myself, this
        # is OK, as the reduction accounts for the copy, unlike the broadcast
        # below.
        if P_send.active:
            reduced_data_send = np.zeros(input_tensor_shape, dtype=dtype)
            input_numpy = input.detach().numpy()
            req = P_send.comm.Ireduce(input_numpy,
                                      reduced_data_send,
                                      root=0,
                                      op=MPI.SUM)
            requests.append(req)

        # If I am the root of the receive partition (and it is distinct from
        # the send partition), I receive the result of the reduction here.
        # mpi4py does not allow aliasing of the send and receive buffers, so
        # we have to make a copy of nothing, unfortunately.
        if P_send != P_recv and P_recv.active:
            reduced_data_recv = np.zeros(output_tensor_shape, dtype=dtype)
            req = P_recv.comm.Ireduce(reduced_data_recv.copy(),
                                      reduced_data_recv,
                                      root=0,
                                      op=MPI.SUM)
            requests.append(req)

        MPI.Request.Waitall(requests)

        # If we had to receive data, we need to tensorify it.
        if P_recv.active:
            if P_send == P_recv:
                output = torch.tensor(reduced_data_send,
                                      requires_grad=output_requires_grad)
            else:
                output = torch.tensor(reduced_data_recv,
                                      requires_grad=output_requires_grad)

        return output
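
Every example here builds its outputs from `zero_volume_tensor`, imported from `distdl.utilities.torch`. For orientation, a minimal sketch of what such a helper might look like is shown below; the actual implementation (shape convention, defaults) may differ.

import torch

def zero_volume_tensor_sketch(b=None, dtype=None, requires_grad=False, device=None):
    # A tensor with a zero-length dimension holds no data ("zero volume") but
    # still carries shape, dtype, and autograd metadata, so inactive workers
    # can return something that flows through the rest of the network.
    if b is None:
        return torch.empty((0,), dtype=dtype, requires_grad=requires_grad, device=device)
    # Optionally keep the batch dimension so callers that inspect shape[0]
    # still see the correct batch size (cf. the preserve_batch flag above).
    return torch.empty((b, 0), dtype=dtype, requires_grad=requires_grad, device=device)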
Example #2
def test_linear_adjoint_input(barrier_fence_fixture, comm_split_fixture,
                              P_x_ranks, P_x_shape, P_y_ranks, P_y_shape,
                              P_w_ranks, P_w_shape, x_global_shape,
                              y_global_shape):

    import numpy as np
    import torch

    from distdl.backends.mpi.partition import MPIPartition
    from distdl.nn.linear import DistributedLinear
    from distdl.utilities.slicing import compute_subshape
    from distdl.utilities.torch import zero_volume_tensor

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return
    P_world = MPIPartition(base_comm)

    # Create the partitions
    P_x_base = P_world.create_partition_inclusive(P_x_ranks)
    P_x = P_x_base.create_cartesian_topology_partition(P_x_shape)

    P_y_base = P_world.create_partition_inclusive(P_y_ranks)
    P_y = P_y_base.create_cartesian_topology_partition(P_y_shape)

    P_w_base = P_world.create_partition_inclusive(P_w_ranks)
    P_w = P_w_base.create_cartesian_topology_partition(P_w_shape)

    x_global_shape = np.asarray(x_global_shape)
    y_global_shape = np.asarray(y_global_shape)

    layer = DistributedLinear(P_x,
                              P_y,
                              P_w,
                              x_global_shape[1],
                              y_global_shape[1],
                              bias=False)

    x = zero_volume_tensor(x_global_shape[0])
    if P_x.active:
        x_local_shape = compute_subshape(P_x.shape, P_x.index, x_global_shape)
        x = torch.Tensor(np.random.randn(*x_local_shape))
    x.requires_grad = True

    y = layer(x)

    dy = zero_volume_tensor(x_global_shape[0])
    if P_y.active:
        dy = torch.Tensor(np.random.randn(*y.shape))

    y.backward(dy)
    dx = x.grad

    x = x.detach()
    dx = dx.detach()
    dy = dy.detach()
    y = y.detach()

    check_adjoint_test_tight(P_world, x, dx, y, dy)
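
The tests above rely on `check_adjoint_test_tight` to verify the adjoint relationship between the forward and backward passes. A minimal sketch of the identity it presumably checks, written against a raw mpi4py communicator rather than the `MPIPartition` the tests pass in (names and tolerance handling are assumptions):

import numpy as np
import torch
from mpi4py import MPI

def global_inner_product(comm, a, b):
    # Each worker holds only a local piece, so sum the partial inner
    # products across all workers; zero-volume tensors contribute nothing.
    local = np.array([torch.sum(a * b).item() if a.numel() > 0 else 0.0])
    total = np.zeros_like(local)
    comm.Allreduce(local, total, op=MPI.SUM)
    return total[0]

def check_adjoint_identity(comm, x, dx, y, dy, rtol=1e-6):
    # For y = F @ x and dx = F* @ dy, the adjoint test requires
    # <y, dy> == <x, dx> up to floating-point error.
    assert np.isclose(global_inner_product(comm, y, dy),
                      global_inner_product(comm, x, dx),
                      rtol=rtol)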
Example #3
def test_simple_conv2d_adjoint_weight(barrier_fence_fixture,
                                      comm_split_fixture,
                                      P_x_ranks, P_x_shape,
                                      x_global_shape):

    import numpy as np
    import torch

    from distdl.backends.mpi.partition import MPIPartition
    from distdl.nn.conv_feature import DistributedFeatureConv2d
    from distdl.utilities.slicing import compute_subshape
    from distdl.utilities.torch import zero_volume_tensor

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return
    P_world = MPIPartition(base_comm)

    # Create the partitions
    P_x_base = P_world.create_partition_inclusive(P_x_ranks)
    P_x = P_x_base.create_cartesian_topology_partition(P_x_shape)

    x_global_shape = np.asarray(x_global_shape)

    layer = DistributedFeatureConv2d(P_x,
                                     in_channels=x_global_shape[1],
                                     out_channels=10,
                                     kernel_size=[3, 3], bias=False)

    x = zero_volume_tensor(x_global_shape[0])
    if P_x.active:
        x_local_shape = compute_subshape(P_x.shape,
                                         P_x.index,
                                         x_global_shape)
        x = torch.randn(*x_local_shape)
    x.requires_grad = True

    y = layer(x)

    dy = zero_volume_tensor(x_global_shape[0])
    if P_x.active:
        dy = torch.randn(*y.shape)

    y.backward(dy)

    W = zero_volume_tensor()
    dW = zero_volume_tensor()
    if P_x.active:
        W = layer.weight.detach()
        dW = layer.weight.grad.detach()

    dy = dy.detach()
    y = y.detach()

    check_adjoint_test_tight(P_world, W, dW, y, dy)

    P_world.deactivate()
    P_x_base.deactivate()
    P_x.deactivate()
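
These tests build local tensor shapes with `compute_subshape(P_x.shape, P_x.index, x_global_shape)`. One plausible even-split scheme is sketched below; the real `distdl.utilities.slicing.compute_subshape` may distribute remainders differently.

import numpy as np

def compute_subshape_sketch(partition_shape, index, global_shape):
    # Split each global dimension across the workers in that dimension as
    # evenly as possible, giving lower-indexed workers the remainder.
    partition_shape = np.atleast_1d(partition_shape)
    index = np.atleast_1d(index)
    global_shape = np.atleast_1d(global_shape)

    base = global_shape // partition_shape
    remainder = global_shape % partition_shape
    # Workers with index < remainder own one extra element in that dimension.
    return base + (index < remainder)

# e.g. compute_subshape_sketch([2, 2], [1, 0], [7, 5]) -> array([3, 3])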
Example #4
def test_transpose_adjoint(barrier_fence_fixture,
                           comm_split_fixture,
                           P_x_ranks, P_x_shape,
                           P_y_ranks, P_y_shape,
                           x_global_shape):

    import numpy as np
    import torch

    from distdl.backends.mpi.partition import MPIPartition
    from distdl.nn.transpose import DistributedTranspose
    from distdl.utilities.slicing import compute_subshape
    from distdl.utilities.torch import zero_volume_tensor

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return
    P_world = MPIPartition(base_comm)

    # Create the partitions
    P_x_base = P_world.create_partition_inclusive(P_x_ranks)
    P_x = P_x_base.create_cartesian_topology_partition(P_x_shape)

    P_y_base = P_world.create_partition_inclusive(P_y_ranks)
    P_y = P_y_base.create_cartesian_topology_partition(P_y_shape)

    # The global tensor size is the same for x and y
    layer = DistributedTranspose(P_x, P_y, preserve_batch=False)

    # Forward Input
    x = zero_volume_tensor()
    if P_x.active:
        x_local_shape = compute_subshape(P_x.shape,
                                         P_x.index,
                                         x_global_shape)
        x = torch.Tensor(np.random.randn(*x_local_shape))
    x.requires_grad = True

    # Adjoint Input
    dy = zero_volume_tensor()
    if P_y.active:
        y_local_shape = compute_subshape(P_y.shape,
                                         P_y.index,
                                         x_global_shape)
        dy = torch.Tensor(np.random.randn(*y_local_shape))

    # y = F @ x
    y = layer(x)

    # dx = F* @ dy
    y.backward(dy)
    dx = x.grad

    x = x.detach()
    dx = dx.detach()
    dy = dy.detach()
    y = y.detach()

    check_adjoint_test_tight(P_world, x, dx, y, dy)
Example #5
    def backward(ctx, grad_output):

        P_send = ctx.P_send
        P_recv = ctx.P_recv
        preserve_batch = ctx.preserve_batch
        input_tensor_structure = ctx.input_tensor_structure
        output_tensor_structure = ctx.output_tensor_structure
        dtype = ctx.dtype

        input_requires_grad = input_tensor_structure[0]
        input_tensor_shape = input_tensor_structure[2]
        output_tensor_shape = output_tensor_structure[2]

        # This allows all ranks to use the same exit path, so that we can be
        # sure that all requests have cleared.
        if preserve_batch:
            grad_input = zero_volume_tensor(grad_output.shape[0])
        else:
            grad_input = zero_volume_tensor()

        requests = []

        # If I received data (either from a remote worker or just from myself)
        # I need to reduce that data.  If I send and receive to myself, this
        # is OK, as the reduction accounts for the copy, unlike the broadcast
        # above.
        if P_recv.active:
            reduced_data_recv = np.zeros(output_tensor_shape, dtype=dtype)
            grad_output_numpy = grad_output.detach().numpy()
            req = P_recv.comm.Ireduce(grad_output_numpy,
                                      reduced_data_recv,
                                      root=0,
                                      op=MPI.SUM)
            requests.append(req)

        # If I sent data in the forward, I have to receive it here, unless I
        # also received that data, in which case I already have it from above.
        # mpi4py does not allow aliasing of the input, so we have to make a
        # copy of nothing, unfortunately.
        if P_send != P_recv and P_send.active:
            reduced_data_send = np.zeros(input_tensor_shape, dtype=dtype)
            req = P_send.comm.Ireduce(reduced_data_send.copy(),
                                      reduced_data_send,
                                      root=0,
                                      op=MPI.SUM)
            requests.append(req)

        MPI.Request.Waitall(requests)

        # If we had to receive data, we need to tensorify it.
        if P_send.active:
            if P_send == P_recv:
                grad_input = torch.tensor(reduced_data_recv,
                                          requires_grad=input_requires_grad)
            else:
                grad_input = torch.tensor(reduced_data_send,
                                          requires_grad=input_requires_grad)

        return grad_input, None, None, None, None, None, None
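
The reduction in this backward pass is the adjoint of a forward broadcast. A small dense-matrix illustration of that relationship, purely for intuition and not part of distdl:

import numpy as np

# Broadcasting a length-n vector to k workers, written as a dense operator:
# B stacks k copies of the identity, so its adjoint B^T sums the k copies,
# which is exactly what the Ireduce above computes in distributed form.
n, k = 4, 3
B = np.vstack([np.eye(n)] * k)   # (k*n, n): broadcast
x = np.random.randn(n)
y = np.random.randn(k * n)

# Adjoint identity <B x, y> == <x, B^T y>.
assert np.isclose(np.dot(B @ x, y), np.dot(x, B.T @ y))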
Example #6
def test_broadcast_adjoint(barrier_fence_fixture, comm_split_fixture,
                           P_x_ranks, P_x_shape, P_y_ranks, P_y_shape,
                           x_global_shape, transpose_src):

    import numpy as np
    import torch

    from distdl.backends.mpi.partition import MPIPartition
    from distdl.nn.broadcast import Broadcast
    from distdl.utilities.torch import zero_volume_tensor

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return
    P_world = MPIPartition(base_comm)

    # Create the partitions
    P_x_base = P_world.create_partition_inclusive(P_x_ranks)
    P_x = P_x_base.create_cartesian_topology_partition(P_x_shape)

    P_y_base = P_world.create_partition_inclusive(P_y_ranks)
    P_y = P_y_base.create_cartesian_topology_partition(P_y_shape)

    # TODO #93: Change this to create a subtensor so we test when local tensors
    # have different shape.  Then, the output size will also be different, which
    # we will have to get from `y` itself.
    x_local_shape = np.asarray(x_global_shape)

    layer = Broadcast(P_x,
                      P_y,
                      transpose_src=transpose_src,
                      preserve_batch=False)

    x = zero_volume_tensor()
    if P_x.active:
        x = torch.Tensor(np.random.randn(*x_local_shape))
    x.requires_grad = True

    dy = zero_volume_tensor()
    if P_y.active:
        # Adjoint Input
        dy = torch.Tensor(np.random.randn(*x_local_shape))

    # y = F @ x
    y = layer(x)

    # dx = F* @ dy
    y.backward(dy)
    dx = x.grad

    x = x.detach()
    dx = dx.detach()
    dy = dy.detach()
    y = y.detach()

    check_adjoint_test_tight(P_world, x, dx, y, dy)
Example #7
    def backward(ctx, grad_output):
        r"""Adjoint function of zero-volume corrector wrapper.

        This method interfaces to the adjoint of the Jacobian of the
        forward  zero-volume corrector operation.

        Parameters
        ----------
        ctx :
            PyTorch context.
        grad_output : `torch.tensor`
            Input tensor.

        Returns
        -------
        output :
            `grad_output` if `input` was not zero-volume, a zero-volume tensor otherwise.

        """

        sh = ctx.sh

        if ctx.zero_volume:
            return zero_volume_tensor(sh[0], device=grad_output.device)
        else:
            return grad_output.clone()
Example #8
    def forward(ctx, input, P_send, P_recv, preserve_batch,
                input_tensor_structure, output_tensor_structure, dtype):

        ctx.P_send = P_send
        ctx.P_recv = P_recv
        ctx.preserve_batch = preserve_batch
        ctx.input_tensor_structure = input_tensor_structure
        ctx.output_tensor_structure = output_tensor_structure
        ctx.dtype = dtype

        output_requires_grad = output_tensor_structure[0]
        output_tensor_shape = output_tensor_structure[2]

        # This allows all ranks to use the same exit path, so that we can be
        # sure that all requests have cleared.
        if preserve_batch:
            output = zero_volume_tensor(input.shape[0])
        else:
            output = zero_volume_tensor()

        # return output
        requests = []

        # Send all of the data
        if P_send.active:
            input_numpy = input.detach().numpy()
            req = P_send.comm.Ibcast(input_numpy, root=0)
            requests.append(req)

        if P_recv.active:
            # If I also send, make a copy.
            if P_send == P_recv:
                output = input.clone()
            # If I just receive, receive the broadcast
            else:
                output = np.zeros(output_tensor_shape, dtype=dtype)

                req = P_recv.comm.Ibcast(output, root=0)
                req.Wait()
                output = torch.tensor(output,
                                      requires_grad=output_requires_grad)

        MPI.Request.Waitall(requests)

        return output
Example #9
    def backward(ctx, grad_output):

        partition = ctx.partition
        sh = ctx.sh

        if partition.rank == 0:
            return grad_output.clone(), None
        else:
            return zero_volume_tensor(sh[0]), None
Example #10
def test_excepts_mismatched_nondivisible_tensor(barrier_fence_fixture,
                                                comm_split_fixture):

    import numpy as np
    import torch

    from distdl.backends.mpi.partition import MPIPartition
    from distdl.nn.repartition import Repartition
    from distdl.utilities.slicing import compute_subshape
    from distdl.utilities.torch import zero_volume_tensor

    device = torch.device('cuda' if use_cuda else 'cpu')

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return
    P_world = MPIPartition(base_comm)

    # A tensor with size 1 in a dimension cannot be partitioned in that
    # dimension.  (See last dimension of output and tensor.)
    in_shape = (1, 4, 1, 1)
    out_shape = (1, 1, 1, 2)
    x_global_shape = np.array([1, 16, 5, 1])

    in_size = np.prod(in_shape)
    out_size = np.prod(out_shape)

    # Create the partitions
    P_x_base = P_world.create_partition_inclusive(np.arange(0, in_size))
    P_x = P_x_base.create_cartesian_topology_partition(in_shape)

    P_y_base = P_world.create_partition_inclusive(
        np.arange(P_world.size - out_size, P_world.size))
    P_y = P_y_base.create_cartesian_topology_partition(out_shape)

    with pytest.raises(ValueError) as e_info:  # noqa: F841
        layer = Repartition(P_x, P_y)
        layer = layer.to(device)

        # Forward Input
        x = zero_volume_tensor(device=device)
        if P_x.active:
            x_local_shape = compute_subshape(P_x.shape, P_x.index,
                                             x_global_shape)
            x = torch.randn(*x_local_shape, device=device)
        x.requires_grad = True

        layer(x)

    P_world.deactivate()
    P_x_base.deactivate()
    P_x.deactivate()
    P_y_base.deactivate()
    P_y.deactivate()
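
The comment in this test notes that a dimension of global extent 1 cannot be split across 2 workers. Toy arithmetic showing why (the actual validation lives inside `Repartition`):

# Splitting a global extent of 1 across 2 workers leaves at least one worker
# with zero elements in that dimension, which the layer rejects.
global_extent = 1
workers = 2
base, rem = divmod(global_extent, workers)
subshapes = [base + (i < rem) for i in range(workers)]
print(subshapes)   # [1, 0] -> one worker would own nothing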
Example #11
    def backward(ctx, grad_output):

        slices = ctx.slices
        buffers = ctx.buffers
        neighbor_ranks = ctx.neighbor_ranks
        P_x = ctx.P_x

        if not P_x.active:
            return zero_volume_tensor(grad_output.shape[0]), None, None, None, None

        if P_x.size == 1:
            return grad_output, None, None, None, None

        grad_output_numpy = grad_output.detach().numpy()

        dim = P_x.dim
        for i in reversed(range(dim)):

            lbs, lgs, rbs, rgs = slices[i]
            lbb, lgb, rbb, rgb = buffers[i]
            lrank, rrank = neighbor_ranks[i]

            if lgb is not None:
                np.copyto(lgb, grad_output_numpy[lgs].ravel())
                grad_output_numpy[lgs] = 0.0
            if rgb is not None:
                np.copyto(rgb, grad_output_numpy[rgs].ravel())
                grad_output_numpy[rgs] = 0.0

            ltag = 0
            rtag = 1

            lrecv_req = P_x.comm.Irecv(lbb, source=lrank, tag=rtag) if lbb is not None else MPI.REQUEST_NULL
            rrecv_req = P_x.comm.Irecv(rbb, source=rrank, tag=ltag) if rbb is not None else MPI.REQUEST_NULL
            lsend_req = P_x.comm.Isend(lgb, dest=lrank, tag=ltag) if lgb is not None else MPI.REQUEST_NULL
            rsend_req = P_x.comm.Isend(rgb, dest=rrank, tag=rtag) if rgb is not None else MPI.REQUEST_NULL

            reqs = [lrecv_req, rrecv_req, lsend_req, rsend_req]
            n_reqs_completed = 0

            while n_reqs_completed < len(reqs):
                status = MPI.Status()
                index = MPI.Request.Waitany(reqs, status)

                if index != MPI.UNDEFINED:
                    if index == 0:
                        newshape = grad_output_numpy[lbs].shape
                        grad_output_numpy[lbs] += lbb.reshape(newshape)
                    elif index == 1:
                        newshape = grad_output_numpy[rbs].shape
                        grad_output_numpy[rbs] += rbb.reshape(newshape)

                n_reqs_completed += 1

        return grad_output, None, None, None, None
Example #12
    def forward(ctx, input, P_x, slices, buffers, neighbor_ranks):

        ctx.slices = slices
        ctx.buffers = buffers
        ctx.neighbor_ranks = neighbor_ranks
        ctx.P_x = P_x

        if not P_x.active:
            return zero_volume_tensor(input.shape[0])

        ctx.mark_dirty(input)

        if P_x.size == 1:
            return input

        input_numpy = input.detach().numpy()

        dim = P_x.dim
        for i in range(dim):

            lbs, lgs, rbs, rgs = slices[i]
            lbb, lgb, rbb, rgb = buffers[i]
            lrank, rrank = neighbor_ranks[i]

            if lbb is not None:
                np.copyto(lbb, input_numpy[lbs].ravel())
            if rbb is not None:
                np.copyto(rbb, input_numpy[rbs].ravel())

            ltag = 0
            rtag = 1

            lrecv_req = P_x.comm.Irecv(lgb, source=lrank, tag=rtag) if lgb is not None else MPI.REQUEST_NULL
            rrecv_req = P_x.comm.Irecv(rgb, source=rrank, tag=ltag) if rgb is not None else MPI.REQUEST_NULL
            lsend_req = P_x.comm.Isend(lbb, dest=lrank, tag=ltag) if lbb is not None else MPI.REQUEST_NULL
            rsend_req = P_x.comm.Isend(rbb, dest=rrank, tag=rtag) if rbb is not None else MPI.REQUEST_NULL

            reqs = [lrecv_req, rrecv_req, lsend_req, rsend_req]
            n_reqs_completed = 0

            while n_reqs_completed < len(reqs):
                status = MPI.Status()
                index = MPI.Request.Waitany(reqs, status)

                if index != MPI.UNDEFINED:
                    if index == 0:
                        newshape = input_numpy[lgs].shape
                        np.copyto(input_numpy[lgs], lgb.reshape(newshape))
                    elif index == 1:
                        newshape = input_numpy[rgs].shape
                        np.copyto(input_numpy[rgs], rgb.reshape(newshape))

                n_reqs_completed += 1

        return input
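
A 1-D toy of the `(slices, buffers)` bookkeeping this forward assumes. The bulk/ghost naming and layout are guesses for illustration only; distdl builds these structures elsewhere from the partition and the halo widths.

import numpy as np

halo = 1
local = np.array([0., 1., 2., 3., 4., 5.])  # one halo cell at each end

lbs = np.s_[halo:2 * halo]     # left bulk:  owned cells sent to the left neighbor
lgs = np.s_[0:halo]            # left ghost: halo cells filled from the left neighbor
rbs = np.s_[-2 * halo:-halo]   # right bulk: owned cells sent to the right neighbor
rgs = np.s_[-halo:]            # right ghost: halo cells filled from the right neighbor

# Flat staging buffers of the kind Isend/Irecv operate on.
lbb = np.empty(halo)
np.copyto(lbb, local[lbs].ravel())                     # pack before Isend
lgb = np.full(halo, 9.0)                               # pretend this arrived via Irecv
np.copyto(local[lgs], lgb.reshape(local[lgs].shape))   # unpack after Waitany
print(local)                                           # [9. 1. 2. 3. 4. 5.]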
Example #13
    def backward(ctx, grad_output):

        P_send = ctx.P_send
        P_recv = ctx.P_recv
        preserve_batch = ctx.preserve_batch
        input_tensor_structure = ctx.input_tensor_structure
        dtype = ctx.dtype

        input_requires_grad = input_tensor_structure[0]
        input_tensor_shape = input_tensor_structure[2]

        # This allows all ranks to use the same exit path, so that we can be
        # sure that all requests have cleared.
        if preserve_batch:
            grad_input = zero_volume_tensor(grad_output.shape[0])
        else:
            grad_input = zero_volume_tensor()

        requests = []

        # If I received the reduction in the forward call, I broadcast my data
        if P_recv.active:
            grad_output_numpy = grad_output.detach().numpy()
            req = P_recv.comm.Ibcast(grad_output_numpy, root=0)
            requests.append(req)

        # If I just receive, receive the broadcast
        if P_send.active:
            # If I both sent and received reduction data, then I copy the "input"
            if P_send == P_recv:
                grad_input = grad_output.clone()
            else:
                grad_input = np.zeros(input_tensor_shape, dtype=dtype)

                req = P_send.comm.Ibcast(grad_input, root=0)
                req.Wait()
                grad_input = torch.tensor(grad_input,
                                          requires_grad=input_requires_grad)

        MPI.Request.Waitall(requests)

        return grad_input, None, None, None, None, None, None
Example #14
    def backward(ctx, grad_output):
        r"""Backward function of distributed all-sum-reduction layer.

        This method implements the adjoint of the Jacobian of the
        all-sum-reduce operation, another all-sum-reduce, using the
        ``MPI_Iallreduce`` function.

        When the current worker is inactive in the ``P_allreduce`` partition,
        it will output a zero-volume tensor.

        Parameters
        ----------
        ctx :
            PyTorch context.
        grad_output : `torch.tensor`
            Input tensor.

        Returns
        -------
        grad_input :
            Output tensor.
        """

        P_allreduce = ctx.P_allreduce
        input_tensor_structure = ctx.input_tensor_structure
        device = ctx.device

        grad_input = zero_volume_tensor(device=device)

        requests = []

        # All-sum-reduce is self-adjoint
        if P_allreduce.active:
            numpy_dtype = torch_to_numpy_dtype_dict[
                input_tensor_structure.dtype]

            reduced_data = np.zeros(input_tensor_structure.shape,
                                    dtype=numpy_dtype)
            grad_output_numpy = grad_output.detach().cpu().numpy()
            req = P_allreduce._comm.Iallreduce(grad_output_numpy,
                                               reduced_data,
                                               op=MPI.SUM)
            requests.append(req)

        MPI.Request.Waitall(requests)

        # If we had to receive data, we need to tensorify it.
        if P_allreduce.active:
            grad_input = torch.tensor(
                reduced_data,
                requires_grad=input_tensor_structure.requires_grad,
                device=device)

        return grad_input, None, None, None
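
Several of these functions index into `torch_to_numpy_dtype_dict` to hand tensors to mpi4py, which operates on NumPy buffers. An illustrative subset of what such a mapping presumably contains:

import numpy as np
import torch

torch_to_numpy_dtype_sketch = {
    torch.float32: np.float32,
    torch.float64: np.float64,
    torch.int32: np.int32,
    torch.int64: np.int64,
}

# Usage mirrors the backward above:
# numpy_dtype = torch_to_numpy_dtype_sketch[some_tensor.dtype]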
Example #15
def test_simple_conv2d_shape(barrier_fence_fixture,
                             comm_split_fixture,
                             P_x_ranks, P_x_shape,
                             x_global_shape,
                             y_local_shape,
                             padding):

    import numpy as np
    import torch

    from distdl.backends.mpi.partition import MPIPartition
    from distdl.nn.conv_feature import DistributedFeatureConv2d
    from distdl.utilities.slicing import compute_subshape
    from distdl.utilities.torch import zero_volume_tensor

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return
    P_world = MPIPartition(base_comm)

    # Create the partitions
    P_x_base = P_world.create_partition_inclusive(P_x_ranks)
    P_x = P_x_base.create_cartesian_topology_partition(P_x_shape)

    x_global_shape = np.asarray(x_global_shape)

    layer = DistributedFeatureConv2d(P_x,
                                     in_channels=x_global_shape[1],
                                     out_channels=10,
                                     kernel_size=[3, 3],
                                     padding=padding,
                                     bias=False)

    x = zero_volume_tensor(x_global_shape[0])
    if P_x.active:
        x_local_shape = compute_subshape(P_x.shape,
                                         P_x.index,
                                         x_global_shape)
        x = torch.zeros(*x_local_shape)
    x.requires_grad = True

    y = layer(x)

    if P_x.active:
        assert(np.array_equal(np.array(y.shape), np.asarray(y_local_shape)))

    P_world.deactivate()
    P_x_base.deactivate()
    P_x.deactivate()
Example #16
def test_excepts_mismatched_output_partition_tensor(barrier_fence_fixture,
                                                    comm_split_fixture):

    import numpy as np
    import torch

    from distdl.backends.mpi.partition import MPIPartition
    from distdl.nn.transpose import DistributedTranspose
    from distdl.utilities.slicing import compute_subshape
    from distdl.utilities.torch import zero_volume_tensor

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return
    P_world = MPIPartition(base_comm)

    # Output partition rank must match tensor rank
    in_shape = (4, 1, 1)
    out_shape = (1, 1, 1, 2)
    x_global_shape = np.array([16, 5, 5])

    in_size = np.prod(in_shape)
    out_size = np.prod(out_shape)

    # Create the partitions
    P_x_base = P_world.create_partition_inclusive(np.arange(0, in_size))
    P_x = P_x_base.create_cartesian_topology_partition(in_shape)

    P_y_base = P_world.create_partition_inclusive(np.arange(P_world.size-out_size, P_world.size))
    P_y = P_y_base.create_cartesian_topology_partition(out_shape)

    with pytest.raises(ValueError) as e_info:  # noqa: F841
        layer = DistributedTranspose(P_x, P_y)

        # Forward Input
        x = zero_volume_tensor()
        if P_x.active:
            x_local_shape = compute_subshape(P_x.shape,
                                             P_x.index,
                                             x_global_shape)
            x = torch.Tensor(np.random.randn(*x_local_shape))
        x.requires_grad = True

        layer(x)
Example #17
def test_batch_norm_no_training(barrier_fence_fixture, P_x_ranks, P_x_shape,
                                input_shape, num_features, eps, momentum,
                                affine, track_running_stats,
                                comm_split_fixture):

    from distdl.backends.mpi.partition import MPIPartition

    device = torch.device('cuda' if use_cuda else 'cpu')

    torch.manual_seed(0)

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return
    P_world = MPIPartition(base_comm)

    # Create the partitions
    num_dimensions = len(input_shape)
    P_in_out_base = P_world.create_partition_inclusive([0])
    P_in_out = P_in_out_base.create_cartesian_topology_partition(
        [1] * num_dimensions)
    P_x_base = P_world.create_partition_inclusive(P_x_ranks)
    P_x = P_x_base.create_cartesian_topology_partition(P_x_shape)

    # Create the input
    if P_world.rank == 0:
        input_eval = torch.rand(input_shape,
                                dtype=torch.float32,
                                device=device)
    else:
        input_eval = zero_volume_tensor(device=device)

    # Create the sequential network
    if P_world.rank == 0:
        seq_net = torch.nn.Sequential(
            torch.nn.BatchNorm1d(num_features=num_features,
                                 eps=eps,
                                 momentum=momentum,
                                 affine=affine,
                                 track_running_stats=track_running_stats))
        seq_net = seq_net.to(device)
    else:
        seq_net = None

    # Evaluate sequential network
    if P_world.rank == 0:
        seq_net.eval()
        seq_out = seq_net(input_eval)
        seq_out = seq_out.detach().cpu()

    # Create distributed network
    dist_net = torch.nn.Sequential(
        distdl.nn.Repartition(P_in_out, P_x),
        distdl.nn.DistributedBatchNorm(
            P_x,
            num_features=num_features,
            eps=eps,
            momentum=momentum,
            affine=affine,
            track_running_stats=track_running_stats),
        distdl.nn.Repartition(P_x, P_in_out))
    dist_net = dist_net.to(device)

    # Evaluate distributed network
    dist_net.eval()
    dist_out = dist_net(input_eval)
    dist_out = dist_out.detach().cpu()

    # Compare the distributed and sequential networks
    if P_world.rank == 0:
        assert dist_out.shape == seq_out.shape

        # Set the absolute tolerance to ~sqrt(e_mach), or the default.
        # PyTorch got its defaults from NumPy, but NumPy defaults to 64-bit
        # floats, not 32-bit floats as torch does.  Consequently, the default
        # torch atol is actually tighter than one can expect from two fp-equal
        # floating point numbers.  The NumPy default of 1e-8 is closer to
        # sqrt(e_mach) for 64-bit numbers.  So we set the 32-bit tolerance to
        # 1e-5, a little tighter than sqrt(1e-7).
        if seq_out.dtype == torch.float64:
            atol = 1e-8
        elif seq_out.dtype == torch.float32:
            atol = 1e-5
        else:
            # torch default
            atol = 1e-8
        assert torch.allclose(dist_out, seq_out, atol=atol)

    P_world.deactivate()
    P_x_base.deactivate()
    P_x.deactivate()
    P_in_out_base.deactivate()
    P_in_out.deactivate()
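
For reference on the tolerance comment above, the square roots of machine epsilon for the two float widths (computed with NumPy):

import numpy as np

# float32: eps ~ 1.19e-07, sqrt(eps) ~ 3.45e-04
# float64: eps ~ 2.22e-16, sqrt(eps) ~ 1.49e-08
print(np.sqrt(np.finfo(np.float32).eps))
print(np.sqrt(np.finfo(np.float64).eps))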
Example #18
    def forward(ctx, input, P_send, P_recv, preserve_batch,
                input_tensor_structure, output_tensor_structure):
        r"""Forward function of distributed sum-reduction layer.

        This method implements the forward sum-reduction operation using the
        ``MPI_Ireduce`` function.

        Any given worker may participate in two MPI reductions, one on the
        ``P_send`` partition and one on the ``P_recv`` partition.  The
        communication pattern and function selection is to avoid potential
        deadlocks due to potential overlaps in these partitions.

        When the current worker is active in its ``P_send`` partition, it
        *always* has data that must be reduced.  Therefore it will always send
        data (through a sum-reduce) to the root of that partition.

        If the current worker is active in ``P_recv``, then it is guaranteed
        to be the root worker of ``P_recv`` and there are two potential
        scenarios.

        1. If the ``P_send`` and ``P_recv`` partitions are distinct, the
           current worker will receive reduced tensor data as the root of an
           additional ``MPI_Ireduce``.
        2. If the ``P_send`` and ``P_recv`` partitions are the same, the
           reduction is completed by the *first* ``MPI_Ireduce`` and the second
           is not necessary, and in fact will cause a deadlock.

        When the current worker is inactive in the ``P_recv`` partition, it will
        output a zero-volume tensor, potentially preserving a non-zero batch
        size.

        Parameters
        ----------
        ctx :
            PyTorch context.
        input : `torch.tensor`
            Input tensor.
        P_send : Partition
            Sending partition current worker is a part of.
        P_recv : Partition
            Receiving partition current worker is a part of.
        preserve_batch : bool
            Indicates if batch size should be preserved for zero-volume outputs.
        input_tensor_structure : tuple
            Tuple containing properties of the input tensor (dimension, shape,
            requires_grad).
        output_tensor_structure : tuple
            Tuple containing properties of the output tensor (dimension, shape,
            requires_grad).

        Returns
        -------
        output :
            Output tensor.

        """

        device = input.device
        ctx.P_send = P_send
        ctx.P_recv = P_recv
        ctx.preserve_batch = preserve_batch
        ctx.input_tensor_structure = input_tensor_structure
        ctx.output_tensor_structure = output_tensor_structure
        ctx.device = device

        # This allows all ranks to use the same exit path, so that we can be
        # sure that all requests have cleared.
        if preserve_batch:
            output = zero_volume_tensor(input.shape[0], device=device)
        else:
            output = zero_volume_tensor(device=device)

        requests = []

        # By design, the roots are always 0 in the cross-communicators
        # If I receive data (either from a remote worker or just from myself)
        # I need to reduce that data.  If I send and receive to myself, this
        # is OK, as the reduction accounts for the copy, unlike the broadcast
        # below.
        if P_send.active:
            numpy_dtype = torch_to_numpy_dtype_dict[
                input_tensor_structure.dtype]
            reduced_data_send = np.zeros(input_tensor_structure.shape,
                                         dtype=numpy_dtype)
            input_numpy = input.detach().cpu().numpy()
            req = P_send._comm.Ireduce(input_numpy,
                                       reduced_data_send,
                                       root=0,
                                       op=MPI.SUM)
            requests.append(req)

        # If I am the root of the receive partition and it is distinct from
        # the send partition, I receive the result of the reduction here.
        if P_send != P_recv and P_recv.active:
            numpy_dtype = torch_to_numpy_dtype_dict[
                output_tensor_structure.dtype]
            reduced_data_recv = np.zeros(output_tensor_structure.shape,
                                         dtype=numpy_dtype)
            req = P_recv._comm.Ireduce(MPI.IN_PLACE,
                                       reduced_data_recv,
                                       root=0,
                                       op=MPI.SUM)
            requests.append(req)

        MPI.Request.Waitall(requests)

        # If we had to receive data, we need to tensorify it.
        if P_recv.active:
            if P_send == P_recv:
                output = torch.tensor(
                    reduced_data_send,
                    requires_grad=output_tensor_structure.requires_grad,
                    device=device)
            else:
                output = torch.tensor(
                    reduced_data_recv,
                    requires_grad=output_tensor_structure.requires_grad,
                    device=device)

        return output
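
Unlike Example #1, the root of `P_recv` here passes `MPI.IN_PLACE` instead of a throwaway copy as the send buffer. A standalone mpi4py toy of that pattern (run under mpirun; not distdl code):

import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD

# Every rank contributes a vector of ones; the root receives the sum.
recv = np.ones(4, dtype=np.float64)
if comm.rank == 0:
    # At the root, MPI.IN_PLACE uses `recv` both as the root's contribution
    # and as the destination of the reduction.
    req = comm.Ireduce(MPI.IN_PLACE, recv, root=0, op=MPI.SUM)
else:
    req = comm.Ireduce(recv, None, root=0, op=MPI.SUM)
req.Wait()

if comm.rank == 0:
    assert np.allclose(recv, comm.size)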
Example #19
def test_batch_norm_with_training(barrier_fence_fixture, P_x_ranks, P_x_shape,
                                  input_shape, num_features, eps, momentum,
                                  affine, track_running_stats, affine_workers,
                                  comm_split_fixture):

    from distdl.backends.mpi.partition import MPIPartition

    torch.manual_seed(0)

    device = torch.device('cuda' if use_cuda else 'cpu')

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return
    P_world = MPIPartition(base_comm)

    # Create the partitions
    num_dimensions = len(input_shape)
    P_in_out_base = P_world.create_partition_inclusive([0])
    P_in_out = P_in_out_base.create_cartesian_topology_partition(
        [1] * num_dimensions)
    P_x_base = P_world.create_partition_inclusive(P_x_ranks)
    P_x = P_x_base.create_cartesian_topology_partition(P_x_shape)
    P_affine = P_world.create_partition_inclusive(affine_workers)

    # Create the input
    if P_world.rank == 0:
        input_train = torch.rand(input_shape,
                                 dtype=torch.float32,
                                 device=device)
        input_eval = torch.rand(input_shape,
                                dtype=torch.float32,
                                device=device)
        exp = torch.rand(input_shape, dtype=torch.float32, device=device)
    else:
        input_train = zero_volume_tensor(device=device)
        input_eval = zero_volume_tensor(device=device)
        exp = zero_volume_tensor(device=device)

    # Create the sequential network
    if len(input_shape) == 2:
        seq_layer = torch.nn.BatchNorm1d
    elif len(input_shape) == 3:
        seq_layer = torch.nn.BatchNorm1d
    elif len(input_shape) == 4:
        seq_layer = torch.nn.BatchNorm2d
    elif len(input_shape) == 5:
        seq_layer = torch.nn.BatchNorm3d
    if P_world.rank == 0:
        seq_bn = seq_layer(num_features=num_features,
                           eps=eps,
                           momentum=momentum,
                           affine=affine,
                           track_running_stats=track_running_stats)
        seq_bn = seq_bn.to(device)

    # Train sequential network
    if P_world.rank == 0:
        seq_bn.train()
        seq_out1 = seq_bn(input_train)
        seq_loss = ((seq_out1 - exp)**2).sum()
        seq_loss.backward()
        seq_grads = [p.grad.detach().cpu() for p in seq_bn.parameters()]
        # Do a manual weight update (this is what optimizer does):
        with torch.no_grad():
            for p in seq_bn.parameters():
                p.copy_(p + 0.1 * p.grad)

    # Evaluate sequential network
    if P_world.rank == 0:
        seq_bn.eval()
        seq_out2 = seq_bn(input_eval)
        seq_out2 = seq_out2.detach().cpu()

    # Create distributed network
    tr1 = distdl.nn.Repartition(P_in_out, P_x)
    tr1 = tr1.to(device)
    dist_bn = distdl.nn.DistributedBatchNorm(
        P_x,
        num_features=num_features,
        eps=eps,
        momentum=momentum,
        affine=affine,
        track_running_stats=track_running_stats)
    dist_bn = dist_bn.to(device)
    tr2 = distdl.nn.Repartition(P_x, P_in_out)
    tr2 = tr2.to(device)

    # Only the affine workers should have trainable parameters:
    if P_world.rank in affine_workers:
        assert len(list(dist_bn.parameters())) == 2
    else:
        assert len(list(dist_bn.parameters())) == 0

    # Train distributed network
    dist_bn.train()
    dist_out1 = tr2(dist_bn(tr1(input_train)))
    dist_loss = ((dist_out1 - exp)**2).sum()
    assert dist_loss.requires_grad
    dist_loss.backward()
    # Note: We expect the batch norm gradient to have more dimensions than
    #       PyTorch's, but both ultimately have volume equal to num_features.
    #       So, reshape them now, and gather them onto rank 0 for comparison.
    if P_world.rank == 0:
        dist_grads = []
    if P_world.rank in affine_workers:
        for p in dist_bn.parameters():
            if affine_workers == [0]:
                parts = [p.grad.detach().cpu()]
            else:
                parts = P_affine._comm.gather(p.grad.detach().cpu(), root=0)
            if P_world.rank == 0:
                grad = torch.cat(parts, 1)
                reshaped = grad.reshape((num_features, ))
                dist_grads.append(reshaped)
    # Do a manual weight update (this is what optimizer does):
    with torch.no_grad():
        for p in dist_bn.parameters():
            p.copy_(p + 0.1 * p.grad)

    # Evaluate distributed network
    dist_bn.eval()
    dist_out2 = tr2(dist_bn(tr1(input_eval)))
    dist_out2 = dist_out2.detach().cpu()

    # Compare the distributed and sequential networks
    if P_world.rank == 0:

        # Set the absolute tolerance to ~sqrt(e_mach), or the default.
        # PyTorch got its defaults from NumPy, but NumPy defaults to 64-bit
        # floats, not 32-bit floats as torch does.  Consequently, the default
        # torch atol is actually tighter than one can expect from two fp-equal
        # floating point numbers.  The NumPy default of 1e-8 is closer to
        # sqrt(e_mach) for 64-bit numbers.  So we set the 32-bit tolerance to
        # sqrt(1e-7), as our usual choice, 1e-5, is too tight here.
        if seq_out1.dtype == torch.float64:
            atol = 1e-8
        elif seq_out1.dtype == torch.float32:
            import math
            atol = math.sqrt(1e-7)
        else:
            # torch default
            atol = 1e-8

        assert dist_out1.shape == seq_out1.shape
        assert torch.allclose(dist_out1,
                              seq_out1,
                              rtol=ERROR_THRESHOLD,
                              atol=atol)
        assert dist_loss.shape == seq_loss.shape
        assert torch.allclose(dist_loss,
                              seq_loss,
                              rtol=ERROR_THRESHOLD,
                              atol=atol)
        for dist_grad, seq_grad in zip(dist_grads, seq_grads):
            assert dist_grad.shape == seq_grad.shape
            assert torch.allclose(dist_grad,
                                  seq_grad,
                                  rtol=ERROR_THRESHOLD,
                                  atol=atol)
        assert dist_out2.shape == seq_out2.shape
        assert torch.allclose(dist_out2,
                              seq_out2,
                              rtol=ERROR_THRESHOLD,
                              atol=atol)

    P_world.deactivate()
    P_x_base.deactivate()
    P_x.deactivate()
    P_in_out_base.deactivate()
    P_in_out.deactivate()
    P_affine.deactivate()
Example #20
    def backward(ctx, grad_output):
        r"""Backward function of distributed sum-reduction layer.

        This method implements the adjoint of the Jacobian of the sum-reduce
        operation, a sum-reduce, using the ``MPI_Ibcast`` function.

        The roles of the respective send and receive partitions are reversed in
        the adjoint algorithm.  Any worker that was the source of reduced data
        in the forward algorithm will be the destination of broadcast data in
        the adjoint.

        Any given worker may participate in two MPI broadcasts, one on the
        ``P_recv`` partition and one on the ``P_send`` partition.  The
        communication pattern and function selection is to avoid potential
        deadlocks due to potential overlaps in these partitions.

        When the current worker is active in its ``P_recv`` partition, it
        *always* has data that it must share.  It will only be active in
        ``P_recv`` if it is the root worker of that partition, therefore, it
        will send tensor data as the root of an ``MPI_Ibcast``.

        When the current worker is active in its ``P_send`` partition, there are
        multiple potential scenarios.

        1. If it is *active* in a ``P_recv`` partition and ``P_recv`` is the
           *same* partition as ``P_send``, then the input subtensor can simply
           be cloned for the output.
        2. If it is *active* in a ``P_recv`` partition and ``P_recv`` is a
           *different* partition from ``P_send``, then it will receive tensor
           data from the root of an ``MPI_Ibcast``.
        3. If it is *inactive* in a ``P_recv`` partition, then it will receive
           tensor data from the root of an ``MPI_Ibcast``.

        When the current worker is inactive in the ``P_send`` partition, it will
        output a zero-volume tensor, potentially preserving a non-zero batch
        size.

        Parameters
        ----------
        ctx :
            PyTorch context.
        grad_output : `torch.tensor`
            Input tensor.

        Returns
        -------
        grad_input :
            Output tensor.
        """

        P_send = ctx.P_send
        P_recv = ctx.P_recv
        preserve_batch = ctx.preserve_batch
        input_tensor_structure = ctx.input_tensor_structure
        device = ctx.device

        assert grad_output.device == device

        # This allows all ranks to use the same exit path, so that we can be
        # sure that all requests have cleared.
        if preserve_batch:
            grad_input = zero_volume_tensor(grad_output.shape[0],
                                            device=device)
        else:
            grad_input = zero_volume_tensor(device=device)

        requests = []

        # If I received the reduction in the forward call, I broadcast my data
        if P_recv.active:
            grad_output_numpy = grad_output.detach().cpu().numpy()
            req = P_recv._comm.Ibcast(grad_output_numpy, root=0)
            requests.append(req)

        # If I just receive, receive the broadcast
        if P_send.active:
            # If I both sent and received reduction data, then I copy the "input"
            if P_send == P_recv:
                grad_input = grad_output.clone()
            else:
                numpy_dtype = torch_to_numpy_dtype_dict[
                    input_tensor_structure.dtype]
                grad_input = np.zeros(input_tensor_structure.shape,
                                      dtype=numpy_dtype)

                req = P_send._comm.Ibcast(grad_input, root=0)
                req.Wait()
                grad_input = torch.tensor(
                    grad_input,
                    requires_grad=input_tensor_structure.requires_grad,
                    device=device)

        MPI.Request.Waitall(requests)

        return grad_input, None, None, None, None, None
Example #21
def test_distributed_loss(barrier_fence_fixture, comm_split_fixture, P_x_ranks,
                          P_x_shape, x_global_shape, SequentialLoss,
                          DistributedLoss):

    import torch

    from distdl.backends.mpi.partition import MPIPartition
    from distdl.nn import DistributedTranspose
    from distdl.utilities.torch import zero_volume_tensor

    device = torch.device('cuda' if use_cuda else 'cpu')

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return
    P_world = MPIPartition(base_comm)

    # Create the partitions
    P_x_base = P_world.create_partition_inclusive(P_x_ranks)
    P_x = P_x_base.create_cartesian_topology_partition(P_x_shape)

    P_0_base = P_x_base.create_partition_inclusive([0])
    P_0 = P_0_base.create_cartesian_topology_partition([1] * len(P_x_shape))

    scatter = DistributedTranspose(P_0, P_x).to(device)
    gather = DistributedTranspose(P_x, P_0).to(device)

    for reduction in DistributedLoss._valid_reductions:

        distributed_criterion = DistributedLoss(P_x,
                                                reduction=reduction).to(device)
        sequential_criterion = SequentialLoss(reduction=reduction).to(device)

        with torch.no_grad():
            x_g = zero_volume_tensor(device=device)
            y_g = zero_volume_tensor(device=device)
            if P_0.active:
                x_g = torch.rand(x_global_shape, device=device)
                y_g = torch.rand(x_global_shape, device=device)

            x_l = scatter(x_g)
            y_l = scatter(y_g)

        x_l.requires_grad = True
        distributed_loss = distributed_criterion(x_l, y_l)

        # For "none", no reduction is applied so we see if it computed the
        # same loss as the sequential code by gathering the loss value it to
        # the root rank.
        if reduction == "none":
            distributed_loss = gather(distributed_loss)

        if P_0.active:
            x_g.requires_grad = True
            sequential_loss = sequential_criterion(x_g, y_g)

            assert (torch.allclose(distributed_loss, sequential_loss))

        # For any other reduction, we can compare the loss
        # value *and* backpropagate through the distributed loss to verify
        # that it produces the same output.
        if reduction != "none":
            distributed_loss.backward()
            distributed_dx_g = gather(x_l.grad)

            if P_0.active:
                sequential_loss.backward()
                sequential_dx_g = x_g.grad

                assert (torch.allclose(distributed_dx_g, sequential_dx_g))

    P_world.deactivate()
    P_x_base.deactivate()
    P_x.deactivate()
    P_0_base.deactivate()
    P_0.deactivate()
Example #22
def test_sum_reduce_dtype(barrier_fence_fixture, comm_split_fixture, dtype,
                          test_backward, P_x_ranks, P_x_shape, P_y_ranks,
                          P_y_shape, x_global_shape, transpose_src):

    import numpy as np
    import torch

    from distdl.backends.mpi.partition import MPIPartition
    from distdl.nn.sum_reduce import SumReduce
    from distdl.utilities.torch import zero_volume_tensor

    device = torch.device('cuda' if use_cuda else 'cpu')

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return
    P_world = MPIPartition(base_comm)

    # Create the partitions
    P_x_base = P_world.create_partition_inclusive(P_x_ranks)
    P_x = P_x_base.create_cartesian_topology_partition(P_x_shape)

    P_y_base = P_world.create_partition_inclusive(P_y_ranks)
    P_y = P_y_base.create_cartesian_topology_partition(P_y_shape)

    # TODO #93: Change this to create a subtensor so we test when local tensors
    # have different shape.  Then, the output size will also be different, which
    # we will have to get from `y` itself.
    x_local_shape = np.asarray(x_global_shape)

    layer = SumReduce(P_x,
                      P_y,
                      transpose_src=transpose_src,
                      preserve_batch=False)
    layer = layer.to(device)

    x = zero_volume_tensor(device=device)
    if P_x.active:
        x = 10 * torch.randn(*x_local_shape, device=device).to(dtype)

    x.requires_grad = test_backward

    # y = F @ x
    y = layer(x)

    # If we are not in the output partition, there is no data to test the type
    # against.
    if P_y.active:
        assert y.dtype == dtype

    if test_backward:
        dy = zero_volume_tensor(device=device)
        if P_y.active:
            # Adjoint Input
            dy = 10 * torch.randn(*x_local_shape, device=device).to(dtype)

        # dx = F* @ dy
        y.backward(dy)
        dx = x.grad

        if P_x.active:
            assert dx.dtype == dtype

    P_world.deactivate()
    P_x_base.deactivate()
    P_x.deactivate()
    P_y_base.deactivate()
    P_y.deactivate()
Example #23
    def backward(ctx, grad_output):
        r"""Adjoint function of distributed transpose layer.

        This method implements the adjoint of the Jacobian of the transpose
        operation using MPI immediate-mode, non-blocking communication.

        The roles of the ``P_x`` and ``P_y`` partitions are reversed, but all
        communication across partitions occurs through the ``P_union``
        partition.

        Data is copied using ``MPI_Irecv`` and ``MPI_Isend``. As is standard
        procedure, the receives are posted first, allowing them to complete as
        they can.  Then, buffers are packed and sent.  Once all sends have
        been posted, received data is unpacked in the order that the receives
        complete.

        When the current worker is inactive in the ``P_x`` partition, it will
        output a zero-volume tensor, potentially preserving a non-zero batch
        size.

        Parameters
        ----------
        ctx :
            PyTorch context.
        grad_output : `torch.tensor`
            Input tensor.

        Returns
        -------
        output :
            Output tensor.

        """

        P_union = ctx.P_union
        x_global_structure = ctx.x_global_structure
        x_local_structure = ctx.x_local_structure

        P_x = ctx.P_x
        P_x_to_y_overlaps = ctx.P_x_to_y_overlaps
        P_x_to_y_buffers = ctx.P_x_to_y_buffers

        P_y = ctx.P_y
        P_y_to_x_overlaps = ctx.P_y_to_x_overlaps
        P_y_to_x_buffers = ctx.P_y_to_x_buffers

        preserve_batch = ctx.preserve_batch

        input_requires_grad = ctx.input_requires_grad

        device = ctx.device

        assert grad_output.device == device

        requests = []

        # Default everyone to output None
        if preserve_batch:
            grad_input = zero_volume_tensor(grad_output.shape[0],
                                            dtype=x_global_structure.dtype,
                                            device=device)
        else:
            grad_input = zero_volume_tensor(dtype=x_global_structure.dtype,
                                            device=device)

        # Recv my input parts
        recv_count = 0
        if P_x.active:
            for (sl, sh, partner), buff in zip(P_x_to_y_overlaps,
                                               P_x_to_y_buffers):
                if buff is not None:
                    xfer_buff = buff.get_view(sh)
                    req = P_union._comm.Irecv(xfer_buff,
                                              source=partner,
                                              tag=113)
                    requests.append(req)
                else:
                    # We add this if there is no recv so that the indices of
                    # the requests array match the indices of
                    # P_x_to_y_overlaps and P_x_to_y_buffers.
                    requests.append(MPI.REQUEST_NULL)
                recv_count += 1

        # Pack and send my input parts
        send_count = 0
        if P_y.active:
            for (sl, sh, partner), buff in zip(P_y_to_x_overlaps,
                                               P_y_to_x_buffers):
                if buff is not None:
                    xfer_buff = buff.get_view(sh)
                    np.copyto(xfer_buff,
                              grad_output.detach()[sl].cpu().numpy())
                    req = P_union._comm.Isend(xfer_buff, dest=partner, tag=113)
                    requests.append(req)
                else:
                    # We add this for symmetry, but don't really need it.
                    requests.append(MPI.REQUEST_NULL)
                send_count += 1

        if P_x.active:
            numpy_dtype = torch_to_numpy_dtype_dict[x_global_structure.dtype]
            grad_input = np.zeros(x_local_structure.shape, dtype=numpy_dtype)

        # Handle the self-copy
        if P_y.active and P_x.active:
            # Find the self patch in x_to_y
            for (ysl, ysh, y2xpartner) in P_y_to_x_overlaps:
                if y2xpartner == "self":
                    for (xsl, xsh, x2ypartner) in P_x_to_y_overlaps:
                        if x2ypartner == "self":
                            np.copyto(grad_input[xsl],
                                      grad_output.detach()[ysl].cpu().numpy())
                            # There is only one case where this can happen
                            break
                    # There is only one case where this can happen
                    break

        # Unpack the received data as it arrives
        completed_count = 0
        while (completed_count < len(requests)):
            status = MPI.Status()
            index = MPI.Request.Waitany(requests, status)

            # In MPI, we don't get the index out if the request is an
            # instance of MPI.REQUEST_NULL; instead, MPI.UNDEFINED is returned.
            if P_x.active and index < recv_count and index != MPI.UNDEFINED:
                # Unpack my output parts
                sl, sh, partner = P_x_to_y_overlaps[index]
                buff = P_x_to_y_buffers[index]
                if buff is not None:
                    xfer_buff = buff.get_view(sh)
                    # This would normally be an add into the grad_input tensor
                    # but we just created it, so a copy is sufficient.
                    np.copyto(grad_input[sl], xfer_buff)

            completed_count += 1

        if P_x.active:
            grad_input = torch.tensor(grad_input,
                                      requires_grad=input_requires_grad,
                                      device=device)

        return grad_input, None, None, None, None, None, None, None, None, None, None, None
# Create the transpose layer
layer = Repartition(P_x, P_y, preserve_batch=False)

# Setup the input tensor.  Any worker in P_x will generate its part of the
# input tensor.  Any worker not in P_x will have a zero-volume tensor.
#
# Input tensor will be (on a 1 x 1 partition):
# [ [ 1 1 1 1 1 ]
#   [ 1 1 1 1 1 ]
#   [ 1 1 1 1 1 ]
#   [ 1 1 1 1 1 ]
#   [ 1 1 1 1 1 ]
#   [ 1 1 1 1 1 ]
#   [ 1 1 1 1 1 ] ]
x = zero_volume_tensor()
if P_x.active:
    x_local_shape = slicing.compute_subshape(P_x.shape, P_x.index,
                                             x_global_shape)
    x = np.zeros(x_local_shape) + P_x.rank + 1
    x = torch.from_numpy(x)
x.requires_grad = True
print(f"rank {P_world.rank}; index {P_x.index}; value {x}")

# Apply the layer.
#
# Output tensor will be (on a 2 x 2 partition):
# [ [ 1 1 1 | 1 1 ]
#   [ 1 1 1 | 1 1 ]
#   [ 1 1 1 | 1 1 ]
#   [ 1 1 1 | 1 1 ]
Example #25
    def forward(ctx, input, P_union, x_global_structure, x_local_structure,
                y_local_structure, P_x, P_x_to_y_overlaps, P_x_to_y_buffers,
                P_y, P_y_to_x_overlaps, P_y_to_x_buffers, preserve_batch):
        r"""Forward function of distributed transpose layer.

        This method implements the forward transpose operation using MPI
        immediate-mode, non-blocking communication.

        Any given worker may send data to multiple workers in ``P_y`` and
        receive data from multiple workers in ``P_x``.  All communication
        across partitions occurs through the ``P_union`` partition.

        Data is copied using ``MPI_Irecv`` and ``MPI_Isend``. As is standard
        procedure, the receives are posted first, allowing them to complete as
        they can.  Then, buffers are packed and sent.  Once all sends have
        been posted, received data is unpacked in the order that the receives
        complete.

        When the current worker is inactive in the ``P_y`` partition, it will
        output a zero-volume tensor, potentially preserving a non-zero batch
        size.

        Parameters
        ----------
        ctx :
            PyTorch context.
        input : `torch.tensor`
            Input tensor.
        P_union : Partition
            Partition through which all communication occurs.
        x_global_structure :
            Structure of the global input tensor.
        x_local_structure :
            Structure of the local input tensor.
        y_local_structure :
            Structure of the local output tensor.
        P_x : Partition
            Input partition.
        P_x_to_y_overlaps : list
            List of tuples (sl, sh, partner) for each send the current worker
            must perform.
        P_x_to_y_buffers : list
            List of pre-allocated send buffers for each send the current
            worker must perform.
        P_y : Partition
            Output partition.
        P_y_to_x_overlaps : list
            List of tuples (sl, sh, partner) for each receive the current
            worker must perform.
        P_y_to_x_buffers : list
            List of pre-allocated receive buffers for each receive the current
            worker must perform.
        preserve_batch : bool
            Indicates if batch size should be preserved for zero-volume outputs.

        Returns
        -------
        output :
            Output tensor.

        """

        ctx.P_union = P_union
        ctx.x_global_structure = x_global_structure
        ctx.x_local_structure = x_local_structure

        ctx.P_x = P_x
        ctx.P_x_to_y_overlaps = P_x_to_y_overlaps
        ctx.P_x_to_y_buffers = P_x_to_y_buffers

        ctx.P_y = P_y
        ctx.P_y_to_x_overlaps = P_y_to_x_overlaps
        ctx.P_y_to_x_buffers = P_y_to_x_buffers

        ctx.preserve_batch = preserve_batch

        device = input.device
        ctx.device = device

        input_requires_grad = False

        # Share the requires-grad status, so that it is preserved across the
        # transpose
        if P_union.active:
            # By design, P_x is always first in the union, so we can just take
            # rank 0's status to send
            if P_x.rank == 0:
                input_requires_grad = input.requires_grad
                P_union._comm.Bcast(np.array([1 if input_requires_grad else 0
                                              ]),
                                    root=0)
            else:
                irg = np.array([0], dtype=int)
                P_union._comm.Bcast(irg, root=0)
                input_requires_grad = bool(irg[0] == 1)

        ctx.input_requires_grad = input_requires_grad

        requests = []

        # Default everyone to output nothing
        if preserve_batch:
            output = zero_volume_tensor(input.shape[0],
                                        dtype=x_global_structure.dtype,
                                        device=device)
        else:
            output = zero_volume_tensor(dtype=x_global_structure.dtype,
                                        device=device)

        # If I am getting data, recv my output parts
        recv_count = 0
        if P_y.active:
            for (sl, sh, partner), buff in zip(P_y_to_x_overlaps,
                                               P_y_to_x_buffers):
                if buff is not None:
                    xfer_buff = buff.get_view(sh)
                    req = P_union._comm.Irecv(xfer_buff,
                                              source=partner,
                                              tag=111)
                    requests.append(req)
                else:
                    # We add this if there is no recv so that the indices of
                    # the requests array match the indices of
                    # P_y_to_x_overlaps and P_y_to_x_buffers.
                    requests.append(MPI.REQUEST_NULL)
                recv_count += 1

        # If I have data to share, pack and send my input parts
        send_count = 0
        if P_x.active:
            for (sl, sh, partner), buff in zip(P_x_to_y_overlaps,
                                               P_x_to_y_buffers):
                if buff is not None:
                    xfer_buff = buff.get_view(sh)
                    np.copyto(xfer_buff, input.detach()[sl].cpu().numpy())
                    req = P_union._comm.Isend(xfer_buff, dest=partner, tag=111)
                    requests.append(req)
                else:
                    # We add this for symmetry, but don't really need it.
                    requests.append(MPI.REQUEST_NULL)
                send_count += 1

        # We do this after the sends so that they can get started before local
        # allocations.
        if P_y.active:
            numpy_dtype = torch_to_numpy_dtype_dict[x_global_structure.dtype]
            output = np.zeros(y_local_structure.shape, dtype=numpy_dtype)

        # Handle the self-copy
        if P_x.active and P_y.active:
            # Find the self patch in x_to_y
            for (xsl, xsh, x2ypartner) in P_x_to_y_overlaps:
                if x2ypartner == "self":
                    for (ysl, ysh, y2xpartner) in P_y_to_x_overlaps:
                        if y2xpartner == "self":
                            np.copyto(output[ysl],
                                      input.detach()[xsl].cpu().numpy())
                            # There is only one case where this can happen
                            break
                    # There is only one case where this can happen
                    break

        # Unpack the received data as it arrives
        completed_count = 0
        while (completed_count < len(requests)):
            status = MPI.Status()
            index = MPI.Request.Waitany(requests, status)

            # In MPI, we don't get the index out if the request is an
            # instance of MPI.REQUEST_NULL; instead, MPI.UNDEFINED is returned.
            if P_y.active and index < recv_count and index != MPI.UNDEFINED:
                # Unpack my output parts
                sl, sh, partner = P_y_to_x_overlaps[index]
                buff = P_y_to_x_buffers[index]
                if buff is not None:
                    xfer_buff = buff.get_view(sh)
                    np.copyto(output[sl], xfer_buff)

            completed_count += 1

        if P_y.active:
            output = torch.tensor(output,
                                  requires_grad=input_requires_grad,
                                  device=device)

        return output
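
The docstring above describes the forward protocol: post receives first, then pack and post sends, then unpack received buffers in the order they complete. The following is a minimal, self-contained mpi4py sketch of that pattern (not DistDL code); the two-rank setup, buffer shape, tag, and file name are illustrative assumptions.

# A sketch of the Irecv-first / Isend / Waitany-unpack pattern, assuming
# exactly two ranks.  Run with: mpiexec -n 2 python transpose_pattern_sketch.py
import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.Get_rank()
partner = 1 - rank  # hypothetical pairing; assumes world size is 2

recv_buff = np.zeros(4, dtype=np.float64)
send_buff = np.full(4, float(rank), dtype=np.float64)

requests = []

# Post the receive before any send so it can complete as soon as data arrives.
requests.append(comm.Irecv(recv_buff, source=partner, tag=111))

# Pack (here the buffer is already packed) and post the send.
requests.append(comm.Isend(send_buff, dest=partner, tag=111))

# Unpack in completion order.  Waitany returns the index of a completed
# request, or MPI.UNDEFINED once only MPI.REQUEST_NULL entries remain.
completed_count = 0
while completed_count < len(requests):
    index = MPI.Request.Waitany(requests)
    if index != MPI.UNDEFINED and index == 0:
        # Index 0 is the receive we posted first.
        print(f"rank {rank} received {recv_buff} from rank {partner}")
    completed_count += 1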
Example #26
    def __init__(self,
                 P_x,
                 P_y,
                 P_w,
                 in_channels=1,
                 out_channels=1,
                 bias=True,
                 *args,
                 **kwargs):

        super(DistributedGeneralConvBase, self).__init__()

        # P_x is 1    x P_ci x P_d-1 x ... x P_0
        self.P_x = P_x
        # P_y is 1    x P_co x P_d-1 x ... x P_0
        self.P_y = P_y
        # P_w is P_co x P_ci x P_d-1 x ... x P_0
        self.P_w = P_w

        self.P_union = self._distdl_backend.Partition()
        if not (self.P_x.active or self.P_y.active or self.P_w.active):
            return

        # This guarantees that P_union rank 0 has the kernel size, stride,
        # padding, and dilation factors
        P_union = P_w.create_partition_union(P_x)
        P_union = P_union.create_partition_union(P_y)
        self.P_union = P_union

        P_w_shape = None
        if P_union.rank == 0:
            P_w_shape = np.array(P_w.shape, dtype=int)
        P_w_shape = P_union.broadcast_data(P_w_shape, root=0)

        P_co = P_w_shape[0]
        P_ci = P_w_shape[1]
        P_channels = [P_co, P_ci]

        P_x_new_shape = []
        if self.P_x.active:
            if (np.any(P_x.shape[2:] != P_w_shape[2:])):
                raise ValueError(
                    "Spatial components of P_x and P_w must match.")
            if P_w_shape[1] != P_x.shape[1]:
                raise ValueError(
                    "Index 2 of P_w dimension must match input channel partition."
                )
            P_x_new_shape = list(P_x.shape)
            P_x_new_shape.insert(1, 1)
            # Currently a hack, removing the batch dimension because P_w does
            # not have one. This is OK because we assume there are no partitions
            # in the batch dimension.
            P_x_new_shape = np.asarray(P_x_new_shape[1:], dtype=int)

        # For the purposes of this layer, we re-cast P_x to have the extra
        # dimension.  This has no impact outside of the layer or on the results.
        self.P_x = self.P_x.create_cartesian_topology_partition(P_x_new_shape)

        P_y_new_shape = []
        if self.P_y.active:
            if (np.any(P_y.shape[2:] != P_w_shape[2:])):
                raise ValueError(
                    "Spatial components of P_y and P_w must match.")
            if P_w_shape[0] != P_y.shape[1]:
                raise ValueError(
                    "Index 1 of P_w dimension must match output channel partition."
                )
            P_y_new_shape = list(P_y.shape)
            P_y_new_shape.insert(2, 1)
            # Currently a hack, removing the batch dimension because P_w does
            # not have one. This is OK because we assume there are no partitions
            # in the batch dimension.
            P_y_new_shape = np.asarray(P_y_new_shape[1:], dtype=int)

        # For the purposes of this layer, we re-cast P_y to have the extra
        # dimension.  This has no impact outside of the layer or on the results.
        self.P_y = self.P_y.create_cartesian_topology_partition(P_y_new_shape)

        P_spatial = P_w_shape[2:]

        self.serial = False
        if self.P_w.size == 1:
            self.serial = True
            self.conv_layer = self.TorchConvType(*args, **kwargs)
            return

        self.receives_weight = False
        self.stores_weight = False
        self.receives_bias = False
        self.stores_bias = False

        # Determine P_r, initialize weights there
        if self.P_w.active:
            # All of P_w always receives the weight
            self.receives_weight = True

            # This subset is taken to be the origin of the spatial component
            w_root_subset = []
            for i, c in enumerate(range_index(P_w.shape)):
                c = np.asarray(c)
                # Find the P_co x P_ci x 1 x ... x 1 subset to store the weights
                if np.all(c[2:] == 0):
                    w_root_subset.append(i)

            self.P_wr_base = self.P_w.create_partition_inclusive(w_root_subset)
            # ones are needed so the broadcast will work
            self.P_wr = self.P_wr_base.create_cartesian_topology_partition(
                [P_co, P_ci] + [1] * len(P_spatial))
            self.stores_weight = self.P_wr.active

            b_subset = []
            for i, c in enumerate(range_index(P_w.shape)):
                c = np.asarray(c)
                # Find the P_co x 1 x P_0 x ... x P_D-1 subset that needs biases in its calculation.
                # This is everywhere the input-channel coordinate is 0.
                if c[1] == 0:
                    b_subset.append(i)

            self.P_b_base = self.P_w.create_partition_inclusive(b_subset)
            self.P_b = self.P_b_base.create_cartesian_topology_partition(
                [P_co] + [1] + list(P_spatial))
            self.receives_bias = self.P_b.active and bias

            # Now find the subset of _that_ which actually stores the learnable parameter.
            b_root_subset = []
            for i, c in enumerate(range_index(P_w.shape)):
                c = np.asarray(c)
                # Find the P_co x 1 x 1 x ... x 1 subset to store the biases
                if np.all(c[1:] == 0):
                    b_root_subset.append(i)

            self.P_br_base = self.P_w.create_partition_inclusive(b_root_subset)
            # ones are needed so the broadcast will work
            self.P_br = self.P_br_base.create_cartesian_topology_partition(
                [P_co] + [1] + [1] * len(P_spatial))
            self.stores_bias = self.P_br.active and bias

            # Correct the input arguments based on local properties
            local_kwargs = {}
            local_kwargs.update(kwargs)

            # Do this before checking serial so that the layer works properly
            # in the serial case
            local_channels = compute_subshape(P_channels, P_w.index[0:2],
                                              [out_channels, in_channels])
            local_out_channels, local_in_channels = local_channels
            local_kwargs["in_channels"] = local_in_channels
            local_kwargs["out_channels"] = local_out_channels

            local_kwargs["bias"] = self.receives_bias
            self.conv_layer = self.TorchConvType(*args, **local_kwargs)

            # If we store the weight it is a learnable parameter iff it is
            # learnable by default in the layer, which it is.
            if self.stores_weight:
                self._weight = torch.nn.Parameter(
                    self.conv_layer.weight.detach())
            else:
                self._weight = zero_volume_tensor()
            # This always exists so we can copy the property
            self._weight.requires_grad = self.conv_layer.weight.requires_grad

            # https://discuss.pytorch.org/t/assign-parameters-to-nn-module-and-have-grad-fn-track-it/62677/2
            new_weight = self.conv_layer.weight.detach() * 0
            new_weight.requires_grad = self.conv_layer.weight.requires_grad
            del self.conv_layer.weight
            self.conv_layer.weight = new_weight

            # If we store the bias, it is a learnable parameter iff it is
            # learnable by default in the layer, which is only true if it
            # exists.
            if self.stores_bias:
                self._bias = torch.nn.Parameter(self.conv_layer.bias.detach())
            else:
                self._bias = zero_volume_tensor()
            # This does not always exist, but when it does we can copy the
            # property.
            if self.receives_bias:
                self._bias.requires_grad = self.conv_layer.bias.requires_grad

                # https://discuss.pytorch.org/t/assign-parameters-to-nn-module-and-have-grad-fn-track-it/62677/2
                new_bias = self.conv_layer.bias.detach() * 0
                new_bias.requires_grad = self.conv_layer.bias.requires_grad
                del self.conv_layer.bias
                self.conv_layer.bias = new_bias

        # Now we need to share the kernel structure.  The kernel size always
        # corresponds to the spatial dimensions.
        self.conv_kernel_size = None
        self.conv_stride = None
        self.conv_padding = None
        self.conv_dilation = None
        if P_union.rank == 0:
            self.conv_kernel_size = np.array(self.conv_layer.kernel_size,
                                             dtype=int)
            self.conv_stride = np.array(self.conv_layer.stride, dtype=int)
            self.conv_padding = np.array(self.conv_layer.padding, dtype=int)
            self.conv_dilation = np.array(self.conv_layer.dilation,
                                          dtype=int)
        self.conv_kernel_size = P_union.broadcast_data(self.conv_kernel_size,
                                                       root=0)
        self.conv_stride = P_union.broadcast_data(self.conv_stride, root=0)
        self.conv_padding = P_union.broadcast_data(self.conv_padding, root=0)
        self.conv_dilation = P_union.broadcast_data(self.conv_dilation, root=0)

        # We need the halo shape, and other info, to fully populate the pad,
        # halo exchange, and unpad layers.  For pad and unpad, we defer their
        # construction to the pre-forward hook.

        self.pad_layer = None
        self.unpad_layer = None

        # We need to be able to remove some data from the input to the conv
        # layer.
        self.needed_slices = None

        # For the halo layer we also defer construction, so that we can have
        # the halo shape for the input.  The halo will allocate its own
        # buffers, but it needs this information at construction to be able
        # to do this in the pre-forward hook.

        self.halo_layer = None

        # Variables for tracking input changes and buffer construction
        self._distdl_is_setup = False
        self._input_shape = None
        self._input_requires_grad = None

        if P_w.active:
            self.w_broadcast = Broadcast(self.P_wr,
                                         self.P_w,
                                         preserve_batch=False)

        if self.receives_bias or self.stores_bias:
            self.b_broadcast = Broadcast(self.P_br,
                                         self.P_b,
                                         preserve_batch=False)

        self.x_broadcast = Broadcast(self.P_x, self.P_w, preserve_batch=True)
        self.y_sum_reduce = SumReduce(self.P_w, self.P_y, preserve_batch=True)
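
The weight- and bias-root subsets above are selected by walking the Cartesian indices of P_w and keeping the flat ranks whose trailing coordinates are zero. Below is a minimal sketch of that selection, independent of DistDL's range_index helper; it assumes row-major ordering of the Cartesian ranks and uses a made-up P_w shape.

# A sketch of the root-subset selection, assuming the flat rank of a worker is
# its row-major (C-order) position in the Cartesian partition P_w.
import numpy as np

P_w_shape = (2, 3, 2, 2)  # hypothetical: P_co=2, P_ci=3, two spatial dims

w_root_subset = []  # P_co x P_ci x 1 x ... x 1: stores the learnable weights
b_subset = []       # P_co x  1  x P_1 x ... x P_0: needs the bias in its sum
b_root_subset = []  # P_co x  1  x 1 x ... x 1: stores the learnable biases

for flat_rank, coords in enumerate(np.ndindex(*P_w_shape)):
    c = np.asarray(coords)
    if np.all(c[2:] == 0):
        w_root_subset.append(flat_rank)
    if c[1] == 0:
        b_subset.append(flat_rank)
    if np.all(c[1:] == 0):
        b_root_subset.append(flat_rank)

print(len(w_root_subset))  # 6  (= P_co * P_ci)
print(len(b_subset))       # 8  (= P_co * number of spatial workers)
print(len(b_root_subset))  # 2  (= P_co)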
Example #27
    def forward(ctx, input, P_allreduce, input_tensor_structure,
                output_tensor_structure):
        r"""Forward function of distributed all-sum-reduction layer.

        This method implements the forward all-sum-reduction operation using the
        ``MPI_Iallreduce`` function on the communicator defined by ``P_allreduce``.

        When the current worker is inactive in the ``P_allreduce`` partition, it will
        output a zero-volume tensor.

        Parameters
        ----------
        ctx :
            PyTorch context.
        input : `torch.tensor`
            Input tensor.
        P_allreduce : Partition
            Partition within which the reduction happens.
        input_tensor_structure : tuple
            Tuple containing properties of the input tensor (dimension, shape,
            requires_grad).
        output_tensor_structure : tuple
            Tuple containing properties of the output tensor (dimension, shape,
            requires_grad).

        Returns
        -------
        output :
            Output tensor.

        """

        device = input.device
        ctx.P_allreduce = P_allreduce
        ctx.input_tensor_structure = input_tensor_structure
        ctx.output_tensor_structure = output_tensor_structure
        ctx.device = device

        output = zero_volume_tensor(device=device)

        requests = []

        # There is no need to specify a root.
        if P_allreduce.active:
            numpy_dtype = torch_to_numpy_dtype_dict[
                input_tensor_structure.dtype]

            reduced_data = np.zeros(input_tensor_structure.shape,
                                    dtype=numpy_dtype)
            input_numpy = input.detach().cpu().numpy()
            req = P_allreduce._comm.Iallreduce(input_numpy,
                                               reduced_data,
                                               op=MPI.SUM)
            requests.append(req)

        MPI.Request.Waitall(requests)

        # If we had to receive data, we need to tensorify it.
        if P_allreduce.active:
            output = torch.tensor(
                reduced_data,
                requires_grad=output_tensor_structure.requires_grad,
                device=device)

        return output
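
For comparison with the rooted Ireduce used in earlier examples, here is a minimal mpi4py sketch (not DistDL code) of the non-blocking all-sum-reduction this forward relies on; the buffer size, rank count, and file name are illustrative.

# A sketch of Iallreduce: every rank contributes a buffer and every rank ends
# up with the element-wise sum, so no root is needed.
# Run with: mpiexec -n 4 python allreduce_sketch.py
import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

send_buff = np.full(3, float(rank), dtype=np.float64)
recv_buff = np.zeros(3, dtype=np.float64)

req = comm.Iallreduce(send_buff, recv_buff, op=MPI.SUM)
MPI.Request.Waitall([req])

# With 4 ranks, every rank now holds [6. 6. 6.]  (0 + 1 + 2 + 3).
print(f"rank {rank}: {recv_buff}")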
Example #28
    def forward(ctx, input, P_union, x_global_shape, P_x, in_data, in_buffers,
                P_y, out_data, out_buffers, preserve_batch, dtype):

        ctx.P_union = P_union
        ctx.x_global_shape = x_global_shape

        ctx.P_x = P_x
        ctx.in_data = in_data
        ctx.in_buffers = in_buffers

        ctx.P_y = P_y
        ctx.out_data = out_data
        ctx.out_buffers = out_buffers

        ctx.preserve_batch = preserve_batch

        ctx.dtype = dtype

        input_requires_grad = False
        # By design, P_x is always first in the union
        if P_union.active:
            if P_x.rank == 0:
                input_requires_grad = input.requires_grad
                P_union.comm.Bcast(np.array([1 if input_requires_grad else 0]),
                                   root=0)
            else:
                irg = np.array([0], dtype=int)
                P_union.comm.Bcast(irg, root=0)
                input_requires_grad = bool(irg[0] == 1)

        ctx.input_requires_grad = input_requires_grad

        requests = []

        # Default everyone to output nothing
        if preserve_batch:
            output = zero_volume_tensor(input.shape[0])
        else:
            output = zero_volume_tensor()

        # If I am getting data, recv my output parts
        recv_count = 0
        if P_y.active:
            for (sl, sz, partner), buff in zip(out_data, out_buffers):
                if buff is not None:
                    req = P_union.comm.Irecv(buff, source=partner, tag=111)
                    requests.append(req)
                else:
                    # We add this if there is no recv so that the indices of
                    # the requests array match the indices of out_data and
                    # out_buffers.
                    requests.append(MPI.REQUEST_NULL)
                recv_count += 1

        # If I have data to share, pack and send my input parts
        send_count = 0
        if P_x.active:
            input_numpy = input.detach().numpy()
            for (sl, sz, partner), buff in zip(in_data, in_buffers):
                if buff is not None:
                    np.copyto(buff, input_numpy[tuple(sl)].ravel())
                    req = P_union.comm.Isend(buff, dest=partner, tag=111)
                    requests.append(req)
                else:
                    # We add this for symmetry, but don't really need it.
                    requests.append(MPI.REQUEST_NULL)
                send_count += 1

        # We do this after the sends so that they can get started before local
        # allocations.
        if P_y.active:
            index = P_y.index
            y_local_shape = compute_subshape(P_y.shape, index, x_global_shape)
            # TODO(#25): The dtype should not be fixed, but correcting this is
            #            a thing that needs to be resolved globally.
            output = np.zeros(y_local_shape, dtype=dtype)

        # Unpack the received data as it arrives
        completed_count = 0
        while (completed_count < len(requests)):
            status = MPI.Status()
            index = MPI.Request.Waitany(requests, status)

            # In MPI, we don't get the index out if the request is an
            # instance of MPI.REQUEST_NULL; instead, MPI.UNDEFINED is returned.
            if P_y.active and index < recv_count and index != MPI.UNDEFINED:
                # Unpack my output parts
                sl, sz, partner = out_data[index]
                buff = out_buffers[index]
                if buff is not None:
                    sh = output[tuple(sl)].shape
                    np.copyto(output[tuple(sl)], buff.reshape(sh))

            completed_count += 1

        if P_y.active:
            output = torch.from_numpy(output)
            output.requires_grad = input_requires_grad

        return output
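
The output allocation above relies on compute_subshape to determine each worker's local slab of the global tensor. The following sketch (not DistDL's implementation) shows one balanced-splitting rule that is consistent with the 7 x 5 on 2 x 2 example earlier on this page: each dimension is divided as evenly as possible, with the remainder going to the lowest-index workers.

# A sketch of balanced subshape computation, assuming the remainder along each
# dimension is assigned to the lowest partition indices.
import numpy as np

def balanced_subshape(partition_shape, index, global_shape):
    subshape = []
    for p, i, n in zip(partition_shape, index, global_shape):
        base, rem = divmod(n, p)
        # Workers with index < rem get one extra element along this dimension.
        subshape.append(base + (1 if i < rem else 0))
    return np.asarray(subshape)

# A 7 x 5 global tensor on a 2 x 2 partition:
for idx in [(0, 0), (0, 1), (1, 0), (1, 1)]:
    print(idx, balanced_subshape((2, 2), idx, (7, 5)))
# (0, 0) -> [4 3], (0, 1) -> [4 2], (1, 0) -> [3 3], (1, 1) -> [3 2]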
Example #29
def test_conv_versus_pytorch(barrier_fence_fixture, comm_split_fixture,
                             P_x_ranks, P_x_shape, input_dimensions,
                             x_global_shape, kernel_size, padding, stride,
                             dilation, bias):

    import numpy as np
    import torch
    from torch.nn import Conv1d
    from torch.nn import Conv2d
    from torch.nn import Conv3d

    from distdl.backends.mpi.partition import MPIPartition
    from distdl.nn.conv_feature import DistributedFeatureConv1d
    from distdl.nn.conv_feature import DistributedFeatureConv2d
    from distdl.nn.conv_feature import DistributedFeatureConv3d
    from distdl.nn.repartition import Repartition
    from distdl.utilities.torch import zero_volume_tensor

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return
    P_world = MPIPartition(base_comm)
    P_world._comm.Barrier()

    # Create the partitions
    P_0_base = P_world.create_partition_inclusive(np.arange(1))
    P_x_base = P_world.create_partition_inclusive(P_x_ranks)
    P_0 = P_0_base.create_cartesian_topology_partition([1] * len(P_x_shape))
    P_x = P_x_base.create_cartesian_topology_partition(P_x_shape)

    scatter_layer_x = Repartition(P_0, P_x)
    scatter_layer_x = scatter_layer_x.to(device)
    scatter_layer_y = Repartition(P_0, P_x)
    scatter_layer_y = scatter_layer_y.to(device)
    gather_layer_x = Repartition(P_x, P_0)
    gather_layer_x = gather_layer_x.to(device)
    gather_layer_y = Repartition(P_x, P_0)
    gather_layer_y = gather_layer_y.to(device)

    # Create the layers
    if input_dimensions == 1:
        dist_layer_type = DistributedFeatureConv1d
        seq_layer_type = Conv1d
    elif input_dimensions == 2:
        dist_layer_type = DistributedFeatureConv2d
        seq_layer_type = Conv2d
    elif input_dimensions == 3:
        dist_layer_type = DistributedFeatureConv3d
        seq_layer_type = Conv3d
    dist_layer = dist_layer_type(P_x,
                                 in_channels=x_global_shape[1],
                                 out_channels=10,
                                 kernel_size=kernel_size,
                                 padding=padding,
                                 stride=stride,
                                 dilation=dilation,
                                 bias=bias)
    dist_layer = dist_layer.to(device)
    if P_0.active:
        seq_layer = seq_layer_type(in_channels=x_global_shape[1],
                                   out_channels=10,
                                   kernel_size=kernel_size,
                                   padding=padding,
                                   stride=stride,
                                   dilation=dilation,
                                   bias=bias)
        # set the weights of both layers to be the same
        seq_layer = seq_layer.to(device)
        weight = torch.rand_like(seq_layer.weight, device=device)
        seq_layer.weight.data = weight
        dist_layer.weight.data = weight
        if bias:
            bias_weight = torch.rand_like(seq_layer.bias, device=device)
            seq_layer.bias.data = bias_weight
            dist_layer.bias.data = bias_weight

    # Forward Input
    x_ref = zero_volume_tensor(device=device)
    x_ref.requires_grad = True
    dy_ref = zero_volume_tensor(device=device)

    # Construct the inputs to the forward and backward functions as well as the
    # the outputs of the sequential layer
    if P_0.active:
        x_ref = torch.randn(*x_global_shape, device=device)
        x_ref.requires_grad = True
        y_ref = seq_layer(x_ref)
        y_global_shape_calc = y_ref.shape

        dy_ref = torch.randn(*y_global_shape_calc, device=device)

        y_ref.backward(dy_ref)
        dx_ref = x_ref.grad

    # Ensure that the scatter is not part of the computation we are testing
    with torch.no_grad():
        x = scatter_layer_x(x_ref.detach())
        dy = scatter_layer_y(dy_ref.detach())

    x.requires_grad = True

    y = dist_layer(x)
    y.backward(dy)
    dx = x.grad

    # Ensure that the gather is not part of the computation we are testing
    with torch.no_grad():
        dx_comp = gather_layer_x(dx.detach())
        y_comp = gather_layer_y(y.detach())

    if P_0.active:

        # Set the absolute tolerance to ~sqrt(e_mach), or the default.
        # PyTorch got its defaults from NumPy, but NumPy defaults to 64-bit
        # floats, not 32-bit floats as torch does.  Consequently, the default
        # torch atol is actually tighter than one can expect from two fp-equal
        # floating point numbers.  The NumPy default of 1e-8 is closer to
        # sqrt(e_mach) for 64-bit numbers, so we set the 32-bit tolerance a
        # little tighter than sqrt(1e-7), to 1e-5.
        if x_ref.dtype == torch.float64:
            atol = 1e-8
        elif x_ref.dtype == torch.float32:
            atol = 1e-5
        else:
            # torch default
            atol = 1e-8

        # Test the result of each entry independently
        assert torch.allclose(y_ref, y_comp, atol=atol)
        assert torch.allclose(dx_ref, dx_comp, atol=atol)

    P_world.deactivate()
    P_0_base.deactivate()
    P_0.deactivate()
    P_x_base.deactivate()
    P_x.deactivate()
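
The tolerance comment in the test above is easy to check numerically: the chosen atol values sit near the square root of machine epsilon for each dtype. A small, purely illustrative check:

# sqrt(machine epsilon) for the two floating-point types used in the test.
import numpy as np

print(np.sqrt(np.finfo(np.float64).eps))  # ~1.49e-08, close to NumPy's 1e-8 default
print(np.sqrt(np.finfo(np.float32).eps))  # ~3.45e-04, so atol = 1e-5 is already tight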
Example #30
    def backward(ctx, grad_output):

        P_union = ctx.P_union
        x_global_shape = ctx.x_global_shape

        P_x = ctx.P_x
        in_data = ctx.in_data
        in_buffers = ctx.in_buffers

        P_y = ctx.P_y
        out_data = ctx.out_data
        out_buffers = ctx.out_buffers

        preserve_batch = ctx.preserve_batch

        dtype = ctx.dtype

        input_requires_grad = ctx.input_requires_grad

        requests = []

        # Default everyone to output None
        if preserve_batch:
            grad_input = zero_volume_tensor(grad_output.shape[0])
        else:
            grad_input = zero_volume_tensor()

        # Recv my input parts
        recv_count = 0
        if P_x.active:
            for (sl, sz, partner), buff in zip(in_data, in_buffers):
                if buff is not None:
                    req = P_union.comm.Irecv(buff, source=partner, tag=113)
                    requests.append(req)
                else:
                    # We add this if there is no recv so that the indices of
                    # the requests array match the indices of in_data and
                    # in_buffers.
                    requests.append(MPI.REQUEST_NULL)
                recv_count += 1

        # Pack and send my input parts
        send_count = 0
        if P_y.active:
            grad_output_numpy = grad_output.detach().numpy()
            for (sl, sz, partner), buff in zip(out_data, out_buffers):
                if buff is not None:
                    np.copyto(buff, grad_output_numpy[tuple(sl)].ravel())
                    req = P_union.comm.Isend(buff, dest=partner, tag=113)
                    requests.append(req)
                else:
                    # We add this for symmetry, but don't really need it.
                    requests.append(MPI.REQUEST_NULL)
                send_count += 1

        if P_x.active:
            index = P_x.index
            x_local_shape = compute_subshape(P_x.shape, index, x_global_shape)
            # TODO(#25): The dtype should not be fixed, but correcting this is
            #            a thing that needs to be resolved globally.
            grad_input = np.zeros(x_local_shape, dtype=dtype)

        # Unpack the received data as it arrives
        completed_count = 0
        while (completed_count < len(requests)):
            status = MPI.Status()
            index = MPI.Request.Waitany(requests, status)

            # In MPI, we don't get the index out if the request is an
            # instance of MPI.REQUEST_NULL; instead, MPI.UNDEFINED is returned.
            if P_x.active and index < recv_count and index != MPI.UNDEFINED:
                # Unpack my output parts
                sl, sz, partner = in_data[index]
                buff = in_buffers[index]
                if buff is not None:
                    sh = grad_input[tuple(sl)].shape
                    # This would normally be an add into the grad_input tensor
                    # but we just created it, so a copy is sufficient.
                    np.copyto(grad_input[tuple(sl)], buff.reshape(sh))

            completed_count += 1

        if P_x.active:
            grad_input = torch.from_numpy(grad_input)
            grad_input.requires_grad = input_requires_grad

        return grad_input, None, None, None, None, None, None, None, None, None, None