Example #1
def test_potentially_deadlocked_send_recv_pairs(barrier_fence_fixture,
                                                comm_split_fixture, P_x_ranks,
                                                P_x_shape, P_w_ranks,
                                                P_w_shape):

    import torch

    from distdl.backends.mpi.partition import MPIPartition
    from distdl.nn.broadcast import Broadcast

    # `use_cuda` is a module-level flag in the enclosing test file.
    device = torch.device('cuda' if use_cuda else 'cpu')

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return
    P_world = MPIPartition(base_comm)

    # Create the partitions
    P_x_base = P_world.create_partition_inclusive(P_x_ranks)
    P_x = P_x_base.create_cartesian_topology_partition(P_x_shape)

    P_w_base = P_world.create_partition_inclusive(P_w_ranks)
    P_w = P_w_base.create_cartesian_topology_partition(P_w_shape)

    layer = Broadcast(P_x, P_w)  # noqa: F841
    layer = layer.to(device)

    P_world.deactivate()
    P_x_base.deactivate()
    P_x.deactivate()
    P_w_base.deactivate()
    P_w.deactivate()
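
Note: `comm_split_fixture` (and `barrier_fence_fixture`) are pytest fixtures defined in the enclosing test suite and not shown in this listing. A minimal sketch of the comm-split pattern, assuming mpi4py and pytest's indirect parametrization (the fixture name and the `(base_comm, active)` protocol come from the examples; the body itself is an assumption):

import pytest

from mpi4py import MPI


@pytest.fixture
def comm_split_fixture(request):
    # Number of ranks the test needs, supplied via indirect parametrization;
    # default to the whole world.
    n = getattr(request, "param", MPI.COMM_WORLD.Get_size())
    world = MPI.COMM_WORLD
    active = world.Get_rank() < n
    # Ranks sharing a color land in the same sub-communicator.
    base_comm = world.Split(color=0 if active else 1, key=world.Get_rank())
    yield base_comm, active
    base_comm.Free()

Each test then unpacks the pair and returns early on inactive ranks, which is exactly the `base_comm, active = comm_split_fixture` idiom used throughout these examples.
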
Example #2
def test_simple_conv2d_adjoint_weight(barrier_fence_fixture,
                                      comm_split_fixture,
                                      P_x_ranks, P_x_shape,
                                      x_global_shape):

    import numpy as np
    import torch

    from distdl.backends.mpi.partition import MPIPartition
    from distdl.nn.conv_feature import DistributedFeatureConv2d
    from distdl.utilities.slicing import compute_subshape
    from distdl.utilities.torch import zero_volume_tensor

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return
    P_world = MPIPartition(base_comm)

    # Create the partitions
    P_x_base = P_world.create_partition_inclusive(P_x_ranks)
    P_x = P_x_base.create_cartesian_topology_partition(P_x_shape)

    x_global_shape = np.asarray(x_global_shape)

    layer = DistributedFeatureConv2d(P_x,
                                     in_channels=x_global_shape[1],
                                     out_channels=10,
                                     kernel_size=[3, 3], bias=False)

    x = zero_volume_tensor(x_global_shape[0])
    if P_x.active:
        x_local_shape = compute_subshape(P_x.shape,
                                         P_x.index,
                                         x_global_shape)
        x = torch.randn(*x_local_shape)
    x.requires_grad = True

    y = layer(x)

    dy = zero_volume_tensor(x_global_shape[0])
    if P_x.active:
        dy = torch.randn(*y.shape)

    y.backward(dy)

    W = zero_volume_tensor()
    dW = zero_volume_tensor()
    if P_x.active:
        W = layer.weight.detach()
        dW = layer.weight.grad.detach()

    dy = dy.detach()
    y = y.detach()

    check_adjoint_test_tight(P_world, W, dW, y, dy)

    P_world.deactivate()
    P_x_base.deactivate()
    P_x.deactivate()
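
`check_adjoint_test_tight` is a helper from the enclosing test suite, also not shown in this listing. It verifies the adjoint identity <F x, dy> = <x, F* dy> for the linear operator under test, summed over all ranks. A rough sketch of that check, with the reduction and tolerance details assumed:

import numpy as np
import torch

from mpi4py import MPI


def check_adjoint_test_tight_sketch(P_world, x, dx, y, dy):
    # Each rank holds one piece of the global tensors, so form the local
    # inner products and reduce them over the world communicator.
    ip_fwd = np.array([torch.sum(y * dy).item()])  # local part of <F x, dy>
    ip_adj = np.array([torch.sum(x * dx).item()])  # local part of <x, F* dy>
    P_world._comm.Allreduce(MPI.IN_PLACE, ip_fwd, op=MPI.SUM)
    P_world._comm.Allreduce(MPI.IN_PLACE, ip_adj, op=MPI.SUM)
    # "Tight" means the identity should hold to near machine precision.
    assert np.isclose(ip_fwd[0], ip_adj[0], rtol=1e-6)
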
Example #3
def test_excepts_mismatched_nondivisible_tensor(barrier_fence_fixture,
                                                comm_split_fixture):

    import numpy as np
    import pytest
    import torch

    from distdl.backends.mpi.partition import MPIPartition
    from distdl.nn.repartition import Repartition
    from distdl.utilities.slicing import compute_subshape
    from distdl.utilities.torch import zero_volume_tensor

    device = torch.device('cuda' if use_cuda else 'cpu')

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return
    P_world = MPIPartition(base_comm)

    # A tensor with size 1 in a dimension cannot be partitioned in that
    # dimension.  (See last dimension of output and tensor.)
    in_shape = (1, 4, 1, 1)
    out_shape = (1, 1, 1, 2)
    x_global_shape = np.array([1, 16, 5, 1])

    in_size = np.prod(in_shape)
    out_size = np.prod(out_shape)

    # Create the partitions
    P_x_base = P_world.create_partition_inclusive(np.arange(0, in_size))
    P_x = P_x_base.create_cartesian_topology_partition(in_shape)

    P_y_base = P_world.create_partition_inclusive(
        np.arange(P_world.size - out_size, P_world.size))
    P_y = P_y_base.create_cartesian_topology_partition(out_shape)

    with pytest.raises(ValueError) as e_info:  # noqa: F841
        layer = Repartition(P_x, P_y)
        layer = layer.to(device)

        # Forward Input
        x = zero_volume_tensor(device=device)
        if P_x.active:
            x_local_shape = compute_subshape(P_x.shape, P_x.index,
                                             x_global_shape)
            x = torch.randn(*x_local_shape, device=device)
        x.requires_grad = True

        layer(x)

    P_world.deactivate()
    P_x_base.deactivate()
    P_x.deactivate()
    P_y_base.deactivate()
    P_y.deactivate()
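
The expected ValueError here follows from simple arithmetic: the output partition requests 2 workers along the last axis, but the global tensor has extent 1 there, and an axis cannot be split into more pieces than it has entries. A one-line check of that precondition (illustrative only, not distdl's validation code):

import numpy as np

out_shape = np.array([1, 1, 1, 2])
x_global_shape = np.array([1, 16, 5, 1])
# True: at least one axis is smaller than the number of pieces requested.
assert np.any(x_global_shape < out_shape)
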
Example #4
def test_conv_class_selection(barrier_fence_fixture, comm_split_fixture,
                              P_x_ranks, P_x_shape, P_y_ranks, P_y_shape,
                              P_w_ranks, P_w_shape, kernel_size,
                              InputLayerType, OutputLayerType):

    from distdl.backends.mpi.partition import MPIPartition

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return
    P_world = MPIPartition(base_comm)

    # Create the partitions
    P_x_base = P_world.create_partition_inclusive(P_x_ranks)
    P_x = P_x_base.create_cartesian_topology_partition(P_x_shape)

    if P_y_ranks is not None:
        P_y_base = P_world.create_partition_inclusive(P_y_ranks)
        P_y = P_y_base.create_cartesian_topology_partition(P_y_shape)
    else:
        P_y_base = None
        P_y = None

    if P_w_ranks is not None:
        P_w_base = P_world.create_partition_inclusive(P_w_ranks)
        P_w = P_w_base.create_cartesian_topology_partition(P_w_shape)
    else:
        P_w_base = None
        P_w = None

    layer = InputLayerType(P_x,
                           P_y=P_y,
                           P_w=P_w,
                           in_channels=3,
                           out_channels=3,
                           kernel_size=kernel_size)

    assert type(layer) is OutputLayerType

    P_world.deactivate()
    P_x_base.deactivate()
    P_x.deactivate()

    if P_y_ranks is not None:
        P_y_base.deactivate()
        P_y.deactivate()

    if P_w_ranks is not None:
        P_w_base.deactivate()
        P_w.deactivate()
Example #5
def test_simple_conv2d_shape(barrier_fence_fixture,
                             comm_split_fixture,
                             P_x_ranks, P_x_shape,
                             x_global_shape,
                             y_local_shape,
                             padding):

    import numpy as np
    import torch

    from distdl.backends.mpi.partition import MPIPartition
    from distdl.nn.conv_feature import DistributedFeatureConv2d
    from distdl.utilities.slicing import compute_subshape
    from distdl.utilities.torch import zero_volume_tensor

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return
    P_world = MPIPartition(base_comm)

    # Create the partitions
    P_x_base = P_world.create_partition_inclusive(P_x_ranks)
    P_x = P_x_base.create_cartesian_topology_partition(P_x_shape)

    x_global_shape = np.asarray(x_global_shape)

    layer = DistributedFeatureConv2d(P_x,
                                     in_channels=x_global_shape[1],
                                     out_channels=10,
                                     kernel_size=[3, 3],
                                     padding=padding,
                                     bias=False)

    x = zero_volume_tensor(x_global_shape[0])
    if P_x.active:
        x_local_shape = compute_subshape(P_x.shape,
                                         P_x.index,
                                         x_global_shape)
        x = torch.zeros(*x_local_shape)
    x.requires_grad = True

    y = layer(x)

    if P_x.active:
        assert np.array_equal(np.array(y.shape), np.asarray(y_local_shape))

    P_world.deactivate()
    P_x_base.deactivate()
    P_x.deactivate()
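
`compute_subshape` maps a worker's index in the cartesian partition to the shape of its local block of the global tensor. The exact remainder policy is a distdl implementation detail; a plausible sketch, giving low-index workers the leftover entries:

import numpy as np


def compute_subshape_sketch(partition_shape, index, global_shape):
    partition_shape = np.atleast_1d(partition_shape)
    index = np.atleast_1d(index)
    global_shape = np.atleast_1d(global_shape)
    quotient = global_shape // partition_shape
    remainder = global_shape % partition_shape
    # Workers whose coordinate falls below the remainder get one extra entry.
    return quotient + (index < remainder)


# A global [1, 1, 10] tensor on a [1, 1, 4] partition splits 3/3/2/2:
print(compute_subshape_sketch([1, 1, 4], [0, 0, 0], [1, 1, 10]))  # [1 1 3]
print(compute_subshape_sketch([1, 1, 4], [0, 0, 3], [1, 1, 10]))  # [1 1 2]
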
Example #6
def test_excepts_no_match(barrier_fence_fixture, comm_split_fixture):

    import numpy as np
    import pytest

    from distdl.backends.mpi.partition import MPIPartition
    from distdl.nn import DistributedConv2d

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return
    P_world = MPIPartition(base_comm)
    P2 = P_world.create_partition_inclusive(np.arange(2))
    P4 = P_world.create_partition_inclusive(np.arange(4))

    ks = [3, 3]

    P_x = P2.create_cartesian_topology_partition([1, 2, 1, 1])
    P_y = P2.create_cartesian_topology_partition([1, 2, 1, 1])
    P_w = P4.create_cartesian_topology_partition([2, 2, 1, 1])

    with pytest.raises(ValueError) as e_info:  # noqa: F841
        layer = DistributedConv2d(  # noqa: F841
            P_x,
            P_y=P_y,
            in_channels=3,
            out_channels=3,
            kernel_size=ks)

    with pytest.raises(ValueError) as e_info:  # noqa: F841
        layer = DistributedConv2d(  # noqa: F841
            P_x,
            P_w=P_w,
            in_channels=3,
            out_channels=3,
            kernel_size=ks)

    P_world.deactivate()
    P2.deactivate()
    P4.deactivate()
    P_x.deactivate()
    P_y.deactivate()
    P_w.deactivate()
Example #7
def test_excepts_mismatched_partitions(barrier_fence_fixture,
                                       comm_split_fixture):

    import numpy as np
    import pytest

    from distdl.backends.mpi.partition import MPIPartition
    from distdl.nn.repartition import Repartition

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return
    P_world = MPIPartition(base_comm)

    in_shape = (1, 4, 1, 1)
    out_shape = (1, 2)

    in_size = np.prod(in_shape)
    out_size = np.prod(out_shape)

    # Create the partitions
    P_x_base = P_world.create_partition_inclusive(np.arange(0, in_size))
    P_x = P_x_base.create_cartesian_topology_partition(in_shape)

    P_y_base = P_world.create_partition_inclusive(
        np.arange(P_world.size - out_size, P_world.size))
    P_y = P_y_base.create_cartesian_topology_partition(out_shape)

    with pytest.raises(ValueError) as e_info:  # noqa: F841
        Repartition(P_x, P_y)

    P_world.deactivate()
    P_x_base.deactivate()
    P_x.deactivate()
    P_y_base.deactivate()
    P_y.deactivate()
Example #8
def test_distributed_loss(barrier_fence_fixture, comm_split_fixture, P_x_ranks,
                          P_x_shape, x_global_shape, SequentialLoss,
                          DistributedLoss):

    import torch

    from distdl.backends.mpi.partition import MPIPartition
    from distdl.nn import DistributedTranspose
    from distdl.utilities.torch import zero_volume_tensor

    device = torch.device('cuda' if use_cuda else 'cpu')

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return
    P_world = MPIPartition(base_comm)

    # Create the partitions
    P_x_base = P_world.create_partition_inclusive(P_x_ranks)
    P_x = P_x_base.create_cartesian_topology_partition(P_x_shape)

    P_0_base = P_x_base.create_partition_inclusive([0])
    P_0 = P_0_base.create_cartesian_topology_partition([1] * len(P_x_shape))

    scatter = DistributedTranspose(P_0, P_x).to(device)
    gather = DistributedTranspose(P_x, P_0).to(device)

    for reduction in DistributedLoss._valid_reductions:

        distributed_criterion = DistributedLoss(P_x,
                                                reduction=reduction).to(device)
        sequential_criterion = SequentialLoss(reduction=reduction).to(device)

        with torch.no_grad():
            x_g = zero_volume_tensor(device=device)
            y_g = zero_volume_tensor(device=device)
            if P_0.active:
                x_g = torch.rand(x_global_shape, device=device)
                y_g = torch.rand(x_global_shape, device=device)

            x_l = scatter(x_g)
            y_l = scatter(y_g)

        x_l.requires_grad = True
        distributed_loss = distributed_criterion(x_l, y_l)

        # For "none", no reduction is applied so we see if it computed the
        # same loss as the sequential code by gathering the loss value it to
        # the root rank.
        if reduction == "none":
            distributed_loss = gather(distributed_loss)

        if P_0.active:
            x_g.requires_grad = True
            sequential_loss = sequential_criterion(x_g, y_g)

            assert (torch.allclose(distributed_loss, sequential_loss))

        # For any other reduction, we can compare the loss value *and*
        # backpropagate through the distributed loss to verify that it
        # produces the same gradient.
        if reduction != "none":
            distributed_loss.backward()
            distributed_dx_g = gather(x_l.grad)

            if P_0.active:
                sequential_loss.backward()
                sequential_dx_g = x_g.grad

                assert (torch.allclose(distributed_dx_g, sequential_dx_g))

    P_world.deactivate()
    P_x_base.deactivate()
    P_x.deactivate()
    P_0_base.deactivate()
    P_0.deactivate()
Example #9
def test_sum_reduce_dtype(barrier_fence_fixture, comm_split_fixture, dtype,
                          test_backward, P_x_ranks, P_x_shape, P_y_ranks,
                          P_y_shape, x_global_shape, transpose_src):

    import numpy as np
    import torch

    from distdl.backends.mpi.partition import MPIPartition
    from distdl.nn.sum_reduce import SumReduce
    from distdl.utilities.torch import zero_volume_tensor

    device = torch.device('cuda' if use_cuda else 'cpu')

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return
    P_world = MPIPartition(base_comm)

    # Create the partitions
    P_x_base = P_world.create_partition_inclusive(P_x_ranks)
    P_x = P_x_base.create_cartesian_topology_partition(P_x_shape)

    P_y_base = P_world.create_partition_inclusive(P_y_ranks)
    P_y = P_y_base.create_cartesian_topology_partition(P_y_shape)

    # TODO #93: Change this to create a subtensor so we test when local tensors
    # have different shape.  Then, the output size will also be different, which
    # we will have to get from `y` itself.
    x_local_shape = np.asarray(x_global_shape)

    layer = SumReduce(P_x,
                      P_y,
                      transpose_src=transpose_src,
                      preserve_batch=False)
    layer = layer.to(device)

    x = zero_volume_tensor(device=device)
    if P_x.active:
        x = 10 * torch.randn(*x_local_shape, device=device).to(dtype)

    x.requires_grad = test_backward

    # y = F @ x
    y = layer(x)

    # If we are not in the output partition, there is no data to test the type
    # against.
    if P_y.active:
        assert y.dtype == dtype

    if test_backward:
        dy = zero_volume_tensor(device=device)
        if P_y.active:
            # Adjoint Input
            dy = 10 * torch.randn(*x_local_shape, device=device).to(dtype)

        # dx = F* @ dy
        y.backward(dy)
        dx = x.grad

        if P_x.active:
            assert dx.dtype == dtype

    P_world.deactivate()
    P_x_base.deactivate()
    P_x.deactivate()
    P_y_base.deactivate()
    P_y.deactivate()
Example #10
def test_matches_sequential(barrier_fence_fixture, comm_split_fixture,
                            P_x_ranks, P_x_shape, input_dimensions,
                            x_global_shape, kernel_size, padding, stride,
                            dilation, layer_type):

    import numpy as np
    import torch

    from distdl.backends.mpi.partition import MPIPartition
    from distdl.nn.repartition import Repartition
    from distdl.utilities.torch import zero_volume_tensor

    device = torch.device('cuda' if use_cuda else 'cpu')

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return
    P_world = MPIPartition(base_comm)
    P_world._comm.Barrier()

    # Create the partitions
    P_0_base = P_world.create_partition_inclusive(np.arange(1))
    P_x_base = P_world.create_partition_inclusive(P_x_ranks)
    P_0 = P_0_base.create_cartesian_topology_partition([1] * len(P_x_shape))
    P_x = P_x_base.create_cartesian_topology_partition(P_x_shape)

    scatter_layer_x = Repartition(P_0, P_x).to(device)
    scatter_layer_y = Repartition(P_0, P_x).to(device)
    gather_layer_x = Repartition(P_x, P_0).to(device)
    gather_layer_y = Repartition(P_x, P_0).to(device)

    # Create the layers
    if input_dimensions == 1:
        if layer_type == 'max':
            from torch.nn import MaxPool1d as SequentialPoolType

            from distdl.nn import DistributedMaxPool1d as DistributedPoolType
        else:
            from torch.nn import AvgPool1d as SequentialPoolType

            from distdl.nn import DistributedAvgPool1d as DistributedPoolType
    elif input_dimensions == 2:
        if layer_type == 'max':
            from torch.nn import MaxPool2d as SequentialPoolType

            from distdl.nn import DistributedMaxPool2d as DistributedPoolType
        else:
            from torch.nn import AvgPool2d as SequentialPoolType

            from distdl.nn import DistributedAvgPool2d as DistributedPoolType
    elif input_dimensions == 3:
        if layer_type == 'max':
            from torch.nn import MaxPool3d as SequentialPoolType

            from distdl.nn import DistributedMaxPool3d as DistributedPoolType
        else:
            from torch.nn import AvgPool3d as SequentialPoolType

            from distdl.nn import DistributedAvgPool3d as DistributedPoolType

    # PyTorch AvgPool doesn't support dilation, so skip the test if the combination comes up
    dilation_is_default = dilation == 1 or all(x == 1 for x in dilation)
    if layer_type == 'avg' and not dilation_is_default:
        return

    layer_kwargs = {
        'kernel_size': kernel_size,
        'padding': padding,
        'stride': stride
    }

    # Only max pool layers support dilation
    if layer_type == 'max':
        layer_kwargs['dilation'] = dilation

    dist_layer = DistributedPoolType(P_x, **layer_kwargs).to(device)
    if P_0.active:
        seq_layer = SequentialPoolType(**layer_kwargs).to(device)

    # Forward Input
    x_ref = zero_volume_tensor(device=device)
    x_ref.requires_grad = True
    dy_ref = zero_volume_tensor(device=device)

    # Construct the inputs to the forward and backward functions as well as
    # the outputs of the sequential layer
    if P_0.active:
        x_ref = torch.randn(*x_global_shape, device=device)
        x_ref.requires_grad = True
        y_ref = seq_layer(x_ref)
        y_global_shape_calc = y_ref.shape

        dy_ref = torch.randn(*y_global_shape_calc, device=device)

        y_ref.backward(dy_ref)
        dx_ref = x_ref.grad

    # Ensure that the scatter is not part of the computation we are testing
    with torch.no_grad():
        x = scatter_layer_x(x_ref.detach())
        dy = scatter_layer_y(dy_ref.detach())

    x.requires_grad = True

    y = dist_layer(x)
    y.backward(dy)
    dx = x.grad

    # Ensure that the gather is not part of the computation we are testing
    with torch.no_grad():
        dx_comp = gather_layer_x(dx.detach())
        y_comp = gather_layer_y(y.detach())

    if P_0.active:

        # Set the absolute tolerance to ~sqrt(e_mach), or the default.
        # PyTorch got its defaults from NumPy, but NumPy defaults to 64-bit
        # floats, not 32-bit floats as torch does.  Consequently, the default
        # torch atol is actually tighter than one can expect from two fp-equal
        # floating point numbers.  The NumPy default of 1e-8 is closer to
        # sqrt(e_mach) for 64-bit numbers.  So we set the 32-bit tolerance to
        # a little tighter than sqrt(1e-7), i.e. 1e-5.
        if x_ref.dtype == torch.float64:
            atol = 1e-8
        elif x_ref.dtype == torch.float32:
            atol = 1e-5
        else:
            # torch default
            atol = 1e-8

        assert torch.allclose(y_ref, y_comp, atol=atol)
        assert torch.allclose(dx_ref, dx_comp, atol=atol)

    P_world.deactivate()
    P_0_base.deactivate()
    P_0.deactivate()
    P_x_base.deactivate()
    P_x.deactivate()
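
The dtype-dependent tolerance logic above recurs in several of these tests; a small helper (hypothetical, not part of distdl) could centralize the choice:

import torch


def default_atol(dtype):
    # ~sqrt(machine epsilon) for the dtype, per the reasoning in the comment
    # above; anything else falls back to the torch/NumPy default.
    if dtype == torch.float32:
        return 1e-5  # a little tighter than sqrt(1e-7)
    return 1e-8      # suits float64

The assertions would then read, e.g., assert torch.allclose(y_ref, y_comp, atol=default_atol(x_ref.dtype)).
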
Example #11
def test_upsample_matches_sequential(barrier_fence_fixture, comm_split_fixture,
                                     P_x_ranks, P_x_shape, x_global_shape,
                                     scale_factor, y_global_shape, mode,
                                     align_corners, use_size):

    import torch

    from distdl.backends.mpi.partition import MPIPartition
    from distdl.nn.transpose import DistributedTranspose
    from distdl.nn.upsampling import DistributedUpsample
    from distdl.utilities.torch import zero_volume_tensor

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return

    # Test align_corners only in the linear case, otherwise ignore it.
    if mode != "linear" and align_corners:
        return

    torch_mode_map = {3: "linear", 4: "bilinear", 5: "trilinear"}
    torch_mode = mode
    if mode == "linear":
        torch_mode = torch_mode_map[len(x_global_shape)]

    torch_align_corners = align_corners
    if mode == "nearest":
        torch_align_corners = None

    P_world = MPIPartition(base_comm)

    # Create the partitions
    P_0_base = P_world.create_partition_inclusive([0])
    P_0 = P_0_base.create_cartesian_topology_partition([1] * len(P_x_shape))

    P_x_base = P_world.create_partition_inclusive(P_x_ranks)
    P_x = P_x_base.create_cartesian_topology_partition(P_x_shape)

    scatter_layer_x = DistributedTranspose(P_0, P_x)
    scatter_layer_y = DistributedTranspose(P_0, P_x)
    gather_layer_x = DistributedTranspose(P_x, P_0)
    gather_layer_y = DistributedTranspose(P_x, P_0)

    if use_size:
        dist_layer = DistributedUpsample(P_x,
                                         size=y_global_shape,
                                         mode=mode,
                                         align_corners=align_corners)
        if P_0.active:
            seq_layer = torch.nn.Upsample(size=y_global_shape[2:],
                                          mode=torch_mode,
                                          align_corners=torch_align_corners)
    else:
        dist_layer = DistributedUpsample(P_x,
                                         scale_factor=scale_factor,
                                         mode=mode,
                                         align_corners=align_corners)
        if P_0.active:
            seq_layer = torch.nn.Upsample(scale_factor=scale_factor,
                                          mode=torch_mode,
                                          align_corners=torch_align_corners)

    # Forward Input
    x_ref = zero_volume_tensor()
    x_ref.requires_grad = True
    dy_ref = zero_volume_tensor()

    # Construct the inputs to the forward and backward functions as well as
    # the outputs of the sequential layer
    if P_0.active:
        x_ref = torch.randn(*x_global_shape)
        x_ref.requires_grad = True
        y_ref = seq_layer(x_ref)
        y_global_shape_calc = y_ref.shape

        dy_ref = torch.randn(*y_global_shape_calc)

        y_ref.backward(dy_ref)
        dx_ref = x_ref.grad

    # Ensure that the scatter is not part of the computation we are testing
    with torch.no_grad():
        x = scatter_layer_x(x_ref.detach())
        dy = scatter_layer_y(dy_ref.detach())

    x.requires_grad = True

    # Because there is no guarantee that any padding is needed in this test,
    # the input x may pass directly to the Halo layer without going through
    # the padding process.  As the halo layer operates in-place, that would
    # mean a leaf-node variable is modified in-place, which PyTorch does not
    # allow.
    #
    # Thus, we clone the input so that it is not a leaf node.
    x_clone = x.clone()
    y = dist_layer(x_clone)
    y.backward(dy)
    dx = x.grad

    # Ensure that the gather is not part of the computation we are testing
    with torch.no_grad():
        dx_comp = gather_layer_x(dx.detach())
        y_comp = gather_layer_y(y.detach())

    if P_0.active:

        # Set the absolute tolerance to ~sqrt(e_mach), or the default.
        # PyTorch got its defaults from NumPy, but NumPy defaults to 64-bit
        # floats, not 32-bit floats as torch does.  Consequently, the default
        # torch atol is actually tighter than one can expect from two fp-equal
        # floating point numbers.  The NumPy default of 1e-8 is closer to
        # sqrt(e_mach) for 64-bit numbers.  So we set the 32-bit tolerance to
        # a little tighter than sqrt(1e-7), i.e. 1e-5.
        if x_ref.dtype == torch.float64:
            atol = 1e-8
        elif x_ref.dtype == torch.float32:
            atol = 1e-5
        else:
            # torch default
            atol = 1e-8

        # Test the result of each entry independently
        assert torch.allclose(y_ref, y_comp, atol=atol)
        assert torch.allclose(dx_ref, dx_comp, atol=atol)

    P_world.deactivate()
    P_0_base.deactivate()
    P_0.deactivate()
    P_x_base.deactivate()
    P_x.deactivate()
Example #12
def test_general_conv1d_adjoint_bias(barrier_fence_fixture, comm_split_fixture,
                                     P_x_ranks, P_x_shape, P_y_ranks,
                                     P_y_shape, P_w_ranks, P_w_shape,
                                     x_global_shape):

    import numpy as np
    import torch

    from distdl.backends.mpi.partition import MPIPartition
    from distdl.nn.conv_general import DistributedGeneralConv1d
    from distdl.utilities.slicing import compute_subshape
    from distdl.utilities.torch import zero_volume_tensor

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return
    P_world = MPIPartition(base_comm)

    # Create the partitions
    P_x_base = P_world.create_partition_inclusive(P_x_ranks)
    P_x = P_x_base.create_cartesian_topology_partition(P_x_shape)

    P_y_base = P_world.create_partition_inclusive(P_y_ranks)
    P_y = P_y_base.create_cartesian_topology_partition(P_y_shape)

    P_w_base = P_world.create_partition_inclusive(P_w_ranks)
    P_w = P_w_base.create_cartesian_topology_partition(P_w_shape)

    x_global_shape = np.asarray(x_global_shape)

    layer = DistributedGeneralConv1d(P_x,
                                     P_y,
                                     P_w,
                                     in_channels=x_global_shape[1],
                                     out_channels=10,
                                     kernel_size=[3],
                                     bias=True)

    x = zero_volume_tensor(x_global_shape[0])
    if P_x.active:
        x_local_shape = compute_subshape(P_x.shape, P_x.index, x_global_shape)
        x = torch.zeros(*x_local_shape)
    x.requires_grad = True

    y = layer(x)

    dy = zero_volume_tensor(x_global_shape[0])
    if P_y.active:
        dy = torch.randn(*y.shape)

    y.backward(dy)

    b = zero_volume_tensor()
    db = zero_volume_tensor()
    if P_w.active and layer.stores_bias:
        b = layer.bias.detach()
        db = layer.bias.grad.detach()

    dy = dy.detach()
    y = y.detach()

    check_adjoint_test_tight(P_world, b, db, y, dy)

    P_world.deactivate()
    P_x_base.deactivate()
    P_x.deactivate()
    P_y_base.deactivate()
    P_y.deactivate()
    P_w_base.deactivate()
    P_w.deactivate()
Example #13
def test_buffer_management_mixed_network(barrier_fence_fixture,
                                         comm_split_fixture):

    import numpy as np
    import torch

    import distdl
    from distdl.backends.mpi.buffer import MPIBufferManager
    from distdl.backends.mpi.partition import MPIPartition
    from distdl.utilities.torch import zero_volume_tensor

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return

    buffer_manager = MPIBufferManager()
    P_world = MPIPartition(base_comm)

    # Create the partitions

    P_1_base = P_world.create_partition_inclusive([0])
    P_1 = P_1_base.create_cartesian_topology_partition([1, 1, 1, 1])

    P_22_base = P_world.create_partition_inclusive([0, 1, 2, 3])
    P_22 = P_22_base.create_cartesian_topology_partition([1, 1, 2, 2])

    tr1 = distdl.nn.DistributedTranspose(P_1,
                                         P_22,
                                         buffer_manager=buffer_manager)
    c1 = distdl.nn.DistributedConv2d(P_22,
                                     in_channels=1,
                                     out_channels=5,
                                     kernel_size=[3, 3],
                                     padding=[1, 1],
                                     bias=False,
                                     buffer_manager=buffer_manager)
    c2 = distdl.nn.DistributedConv2d(P_22,
                                     in_channels=5,
                                     out_channels=10,
                                     kernel_size=[3, 3],
                                     padding=[1, 1],
                                     bias=False,
                                     buffer_manager=buffer_manager)
    c3 = distdl.nn.DistributedConv2d(P_22,
                                     in_channels=10,
                                     out_channels=20,
                                     kernel_size=[3, 3],
                                     padding=[1, 1],
                                     bias=False,
                                     buffer_manager=buffer_manager)
    tr2 = distdl.nn.DistributedTranspose(P_22,
                                         P_1,
                                         buffer_manager=buffer_manager)

    x = zero_volume_tensor(1)
    if P_1.active:
        x = torch.randn(1, 1, 5, 5)
    x.requires_grad = True

    # [[00   01   02   03   04]      [[00   01   02]  [[03   04]
    #  [10   11   12   13   14]       [10   11   12]   [13   14]]
    #  [20   21   22   23   24]   to  [20   21   22]] [[23   24]
    #  [30   31   32   33   34]      [[30   31   32]   [33   34]
    #  [40   41   42   43   44]]      [40   41   42]]  [43   44]]
    x2 = tr1(x)
    n_buffers_by_rank = (3, 1, 1, 1)
    assert len(buffer_manager.buffers_map[np.float32]) == n_buffers_by_rank[
        P_world.rank]

    # [[00   01   02]  [[03   04]     [[00   01   02]  [[03   04]
    #  [10   11   12]   [13   14]]     [10   11   12]   [13   14]]
    #  [20   21   22]] [[23   24]  to  [20   21   22]] [[23   24]
    # [[30   31   32]   [33   34]     [[30   31   32]   [33   34]
    #  [40   41   42]]  [43   44]]     [40   41   42]]  [43   44]]
    x3 = c1(x2)
    n_buffers_by_rank = (4, 4, 4, 4)
    assert len(buffer_manager.buffers_map[np.float32]) == n_buffers_by_rank[
        P_world.rank]

    # [[00   01   02]  [[03   04]     [[00   01   02]  [[03   04]
    #  [10   11   12]   [13   14]]     [10   11   12]   [13   14]]
    #  [20   21   22]] [[23   24]  to  [20   21   22]] [[23   24]
    # [[30   31   32]   [33   34]     [[30   31   32]   [33   34]
    #  [40   41   42]]  [43   44]]     [40   41   42]]  [43   44]]
    x4 = c2(x3)
    n_buffers_by_rank = (4, 4, 4, 4)
    assert len(buffer_manager.buffers_map[np.float32]) == n_buffers_by_rank[
        P_world.rank]

    # [[00   01   02]  [[03   04]     [[00   01   02]  [[03   04]
    #  [10   11   12]   [13   14]]     [10   11   12]   [13   14]]
    #  [20   21   22]] [[23   24]  to  [20   21   22]] [[23   24]
    # [[30   31   32]   [33   34]     [[30   31   32]   [33   34]
    #  [40   41   42]]  [43   44]]     [40   41   42]]  [43   44]]
    x5 = c3(x4)
    n_buffers_by_rank = (4, 4, 4, 4)
    assert len(buffer_manager.buffers_map[np.float32]) == n_buffers_by_rank[
        P_world.rank]

    # [[00   01   02]  [[03   04]     [[00   01   02   03   04]
    #  [10   11   12]   [13   14]]     [10   11   12   13   14]
    #  [20   21   22]] [[23   24]  to  [20   21   22   23   24]
    # [[30   31   32]   [33   34]      [30   31   32   33   34]
    #  [40   41   42]]  [43   44]]     [40   41   42   43   44]]
    y = tr2(x5)
    n_buffers_by_rank = (4, 4, 4, 4)
    assert len(buffer_manager.buffers_map[np.float32]) == n_buffers_by_rank[
        P_world.rank]

    dy = zero_volume_tensor(1)
    if P_1.active:
        dy = torch.randn(1, 20, 5, 5)
    dy.requires_grad = True

    y.backward(dy)
    dx = x.grad

    # The buffer count does not change through the backward call
    n_buffers_by_rank = (4, 4, 4, 4)
    assert len(buffer_manager.buffers_map[np.float32]) == n_buffers_by_rank[
        P_world.rank]

    # And adjointness is still preserved

    x = x.detach()
    dx = dx.detach()
    dy = dy.detach()
    y = y.detach()

    check_adjoint_test_tight(P_world, x, dx, y, dy)

    P_world.deactivate()
    P_1_base.deactivate()
    P_1.deactivate()
    P_22_base.deactivate()
    P_22.deactivate()
Example #14
def test_linear_adjoint_input(barrier_fence_fixture,
                              comm_split_fixture,
                              P_x_ranks, P_x_shape,
                              P_y_ranks, P_y_shape,
                              P_w_ranks, P_w_shape,
                              x_global_shape,
                              y_global_shape):

    import numpy as np
    import torch

    from distdl.backends.mpi.partition import MPIPartition
    from distdl.nn.linear import DistributedLinear
    from distdl.utilities.slicing import compute_subshape
    from distdl.utilities.torch import zero_volume_tensor

    device = torch.device('cuda' if use_cuda else 'cpu')

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return
    P_world = MPIPartition(base_comm)

    # Create the partitions
    P_x_base = P_world.create_partition_inclusive(P_x_ranks)
    P_x = P_x_base.create_cartesian_topology_partition(P_x_shape)

    P_y_base = P_world.create_partition_inclusive(P_y_ranks)
    P_y = P_y_base.create_cartesian_topology_partition(P_y_shape)

    P_w_base = P_world.create_partition_inclusive(P_w_ranks)
    P_w = P_w_base.create_cartesian_topology_partition(P_w_shape)

    x_global_shape = np.asarray(x_global_shape)
    y_global_shape = np.asarray(y_global_shape)

    layer = DistributedLinear(P_x, P_y, P_w,
                              x_global_shape[1],
                              y_global_shape[1],
                              bias=False)
    layer = layer.to(device)

    x = zero_volume_tensor(x_global_shape[0], device=device)
    if P_x.active:
        x_local_shape = compute_subshape(P_x.shape,
                                         P_x.index,
                                         x_global_shape)
        x = torch.randn(*x_local_shape, device=device)
    x.requires_grad = True

    y = layer(x)

    dy = zero_volume_tensor(x_global_shape[0], device=device)
    if P_y.active:
        dy = torch.randn(*y.shape, device=device)

    y.backward(dy)
    dx = x.grad

    x = x.detach()
    dx = dx.detach()
    dy = dy.detach()
    y = y.detach()

    check_adjoint_test_tight(P_world, x, dx, y, dy)

    P_world.deactivate()
    P_x_base.deactivate()
    P_x.deactivate()
    P_y_base.deactivate()
    P_y.deactivate()
    P_w_base.deactivate()
    P_w.deactivate()
Example #15
def test_mixin():

    import numpy as np

    from mpi4py import MPI

    from distdl.backends.mpi.partition import MPIPartition

    P_world = MPIPartition(MPI.COMM_WORLD)
    ranks = np.arange(P_world.size)

    shape = [1, 1, 4]
    P_size = np.prod(shape)
    use_ranks = ranks[:P_size]

    P_x_base = P_world.create_partition_inclusive(use_ranks)
    P_x = P_x_base.create_cartesian_topology_partition(shape)
    rank = P_x.rank

    layer = MockPoolLayer()  # stub from the enclosing test module exposing the pooling mixin

    x_global_shape = np.array([1, 1, 10])
    kernel_size = np.array([2])
    stride = np.array([2])
    padding = np.array([0])
    dilation = np.array([1])

    halo_shape, recv_buffer_shape, send_buffer_shape, needed_ranges = \
        layer._compute_exchange_info(x_global_shape,
                                     kernel_size,
                                     stride,
                                     padding,
                                     dilation,
                                     P_x.active,
                                     P_x.shape,
                                     P_x.index)

    if P_x.active:
        # Expected results, keyed by rank: (halo_shape, recv_buffer_shape,
        # send_buffer_shape, needed_ranges).
        expected = {
            0: (np.array([[0, 0], [0, 0], [0, 1]]),
                np.array([[0, 0], [0, 0], [0, 1]]),
                np.array([[0, 0], [0, 0], [0, 0]]),
                np.array([[0, 1], [0, 1], [0, 4]])),
            1: (np.array([[0, 0], [0, 0], [0, 0]]),
                np.array([[0, 0], [0, 0], [0, 0]]),
                np.array([[0, 0], [0, 0], [1, 0]]),
                np.array([[0, 1], [0, 1], [1, 3]])),
            2: (np.array([[0, 0], [0, 0], [0, 0]]),
                np.array([[0, 0], [0, 0], [0, 0]]),
                np.array([[0, 0], [0, 0], [0, 0]]),
                np.array([[0, 1], [0, 1], [0, 2]])),
            3: (np.array([[0, 0], [0, 0], [0, 0]]),
                np.array([[0, 0], [0, 0], [0, 0]]),
                np.array([[0, 0], [0, 0], [0, 0]]),
                np.array([[0, 1], [0, 1], [0, 2]])),
        }
        (expected_halo_shape,
         expected_recv_buffer_shape,
         expected_send_buffer_shape,
         expected_needed_ranges) = expected[rank]

        assert np.array_equal(halo_shape, expected_halo_shape)
        assert np.array_equal(recv_buffer_shape, expected_recv_buffer_shape)
        assert np.array_equal(send_buffer_shape, expected_send_buffer_shape)
        assert np.array_equal(needed_ranges, expected_needed_ranges)

    # Inactive ranks should get null results
    else:
        assert halo_shape is None
        assert recv_buffer_shape is None
        assert send_buffer_shape is None
        assert needed_ranges is None

    P_world.deactivate()
    P_x_base.deactivate()
    P_x.deactivate()
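
As a worked check of rank 0's expected halo above (an illustration of the arithmetic, not distdl's algorithm): the 10 input entries split 3/3/2/2 across the four workers, and the 5 pooled outputs split 2/1/1/1. Rank 0's two outputs consume global inputs 0..3, but it owns only 0..2, so it needs one halo entry on the right, which rank 1 sends from its left edge:

kernel_size, stride = 2, 2
owned_last = 2                 # rank 0 owns global inputs 0..2
last_output = 1                # rank 0 owns outputs 0..1
# Global index of the last input needed by rank 0's last output.
needed_last = last_output * stride + kernel_size - 1  # = 3
right_halo = needed_last - owned_last                 # = 1, as asserted above
assert right_halo == 1
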
Example #16
def test_batch_norm_with_training(barrier_fence_fixture, P_x_ranks, P_x_shape,
                                  input_shape, num_features, eps, momentum,
                                  affine, track_running_stats, affine_workers,
                                  comm_split_fixture):

    import torch

    import distdl
    from distdl.backends.mpi.partition import MPIPartition
    from distdl.utilities.torch import zero_volume_tensor

    # ERROR_THRESHOLD is a relative-tolerance constant defined at module
    # level in the enclosing test file.

    torch.manual_seed(0)

    device = torch.device('cuda' if use_cuda else 'cpu')

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return
    P_world = MPIPartition(base_comm)

    # Create the partitions
    num_dimensions = len(input_shape)
    P_in_out_base = P_world.create_partition_inclusive([0])
    P_in_out = P_in_out_base.create_cartesian_topology_partition(
        [1] * num_dimensions)
    P_x_base = P_world.create_partition_inclusive(P_x_ranks)
    P_x = P_x_base.create_cartesian_topology_partition(P_x_shape)
    P_affine = P_world.create_partition_inclusive(affine_workers)

    # Create the input
    if P_world.rank == 0:
        input_train = torch.rand(input_shape,
                                 dtype=torch.float32,
                                 device=device)
        input_eval = torch.rand(input_shape,
                                dtype=torch.float32,
                                device=device)
        exp = torch.rand(input_shape, dtype=torch.float32, device=device)
    else:
        input_train = zero_volume_tensor(device=device)
        input_eval = zero_volume_tensor(device=device)
        exp = zero_volume_tensor(device=device)

    # Create the sequential network
    # BatchNorm1d covers both 2d and 3d inputs.
    if len(input_shape) in (2, 3):
        seq_layer = torch.nn.BatchNorm1d
    elif len(input_shape) == 4:
        seq_layer = torch.nn.BatchNorm2d
    elif len(input_shape) == 5:
        seq_layer = torch.nn.BatchNorm3d
    if P_world.rank == 0:
        seq_bn = seq_layer(num_features=num_features,
                           eps=eps,
                           momentum=momentum,
                           affine=affine,
                           track_running_stats=track_running_stats)
        seq_bn = seq_bn.to(device)

    # Train sequential network
    if P_world.rank == 0:
        seq_bn.train()
        seq_out1 = seq_bn(input_train)
        seq_loss = ((seq_out1 - exp)**2).sum()
        seq_loss.backward()
        seq_grads = [p.grad.detach().cpu() for p in seq_bn.parameters()]
        # Do a manual weight update (this is what the optimizer does):
        with torch.no_grad():
            for p in seq_bn.parameters():
                p.copy_(p + 0.1 * p.grad)

    # Evaluate sequential network
    if P_world.rank == 0:
        seq_bn.eval()
        seq_out2 = seq_bn(input_eval)
        seq_out2 = seq_out2.detach().cpu()

    # Create distributed network
    tr1 = distdl.nn.Repartition(P_in_out, P_x)
    tr1 = tr1.to(device)
    dist_bn = distdl.nn.DistributedBatchNorm(
        P_x,
        num_features=num_features,
        eps=eps,
        momentum=momentum,
        affine=affine,
        track_running_stats=track_running_stats)
    dist_bn = dist_bn.to(device)
    tr2 = distdl.nn.Repartition(P_x, P_in_out)
    tr2 = tr2.to(device)

    # Only the affine workers should have trainable parameters:
    if P_world.rank in affine_workers:
        assert len(list(dist_bn.parameters())) == 2
    else:
        assert len(list(dist_bn.parameters())) == 0

    # Train distributed network
    dist_bn.train()
    dist_out1 = tr2(dist_bn(tr1(input_train)))
    dist_loss = ((dist_out1 - exp)**2).sum()
    assert dist_loss.requires_grad
    dist_loss.backward()
    # Note: We expect the batch norm gradient to have more dimensions than
    #       PyTorch's, but both ultimately have volume equal to num_features.
    #       So, reshape them now, and gather them onto rank 0 for comparison.
    if P_world.rank == 0:
        dist_grads = []
    if P_world.rank in affine_workers:
        for p in dist_bn.parameters():
            if affine_workers == [0]:
                parts = [p.grad.detach().cpu()]
            else:
                parts = P_affine._comm.gather(p.grad.detach().cpu(), root=0)
            if P_world.rank == 0:
                grad = torch.cat(parts, 1)
                reshaped = grad.reshape((num_features, ))
                dist_grads.append(reshaped)
    # Do a manual weight update (this is what the optimizer does):
    with torch.no_grad():
        for p in dist_bn.parameters():
            p.copy_(p + 0.1 * p.grad)

    # Evaluate distributed network
    dist_bn.eval()
    dist_out2 = tr2(dist_bn(tr1(input_eval)))
    dist_out2 = dist_out2.detach().cpu()

    # Compare the distributed and sequential networks
    if P_world.rank == 0:

        # Set the absolute tolerance to ~sqrt(e_mach), or the default.
        # PyTorch got its defaults from NumPy, but NumPy defaults to 64-bit
        # floats, not 32-bit floats as torch does.  Consequently, the default
        # torch atol is actually tighter than one can expect from two fp-equal
        # floating point numbers.  The NumPy default of 1e-8 is closer to
        # sqrt(e_mach) for 64-bit numbers.  So we set the 32-bit tolerance to
        # sqrt(1e-7), as our usual choice, 1e-5, is too tight here.
        if seq_out1.dtype == torch.float64:
            atol = 1e-8
        elif seq_out1.dtype == torch.float32:
            import math
            atol = math.sqrt(1e-7)
        else:
            # torch default
            atol = 1e-8

        assert dist_out1.shape == seq_out1.shape
        assert torch.allclose(dist_out1,
                              seq_out1,
                              rtol=ERROR_THRESHOLD,
                              atol=atol)
        assert dist_loss.shape == seq_loss.shape
        assert torch.allclose(dist_loss,
                              seq_loss,
                              rtol=ERROR_THRESHOLD,
                              atol=atol)
        for dist_grad, seq_grad in zip(dist_grads, seq_grads):
            assert dist_grad.shape == seq_grad.shape
            assert torch.allclose(dist_grad,
                                  seq_grad,
                                  rtol=ERROR_THRESHOLD,
                                  atol=atol)
        assert dist_out2.shape == seq_out2.shape
        assert torch.allclose(dist_out2,
                              seq_out2,
                              rtol=ERROR_THRESHOLD,
                              atol=atol)

    P_world.deactivate()
    P_x_base.deactivate()
    P_x.deactivate()
    P_in_out_base.deactivate()
    P_in_out.deactivate()
    P_affine.deactivate()
Example #17
def test_broadcast_adjoint(barrier_fence_fixture, comm_split_fixture,
                           P_x_ranks, P_x_shape, P_y_ranks, P_y_shape,
                           x_global_shape, transpose_src):

    import numpy as np
    import torch

    from distdl.backends.mpi.partition import MPIPartition
    from distdl.nn.broadcast import Broadcast
    from distdl.utilities.torch import zero_volume_tensor

    device = torch.device('cuda' if use_cuda else 'cpu')

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return
    P_world = MPIPartition(base_comm)

    # Create the partitions
    P_x_base = P_world.create_partition_inclusive(P_x_ranks)
    P_x = P_x_base.create_cartesian_topology_partition(P_x_shape)

    P_y_base = P_world.create_partition_inclusive(P_y_ranks)
    P_y = P_y_base.create_cartesian_topology_partition(P_y_shape)

    # TODO #93: Change this to create a subtensor so we test when local tensors
    # have different shape.  Then, the output size will also be different, which
    # we will have to get from `y` itself.
    x_local_shape = np.asarray(x_global_shape)

    layer = Broadcast(P_x,
                      P_y,
                      transpose_src=transpose_src,
                      preserve_batch=False)
    layer = layer.to(device)

    x = zero_volume_tensor(device=device)
    if P_x.active:
        x = torch.randn(*x_local_shape, device=device)
    x.requires_grad = True

    dy = zero_volume_tensor(device=device)
    if P_y.active:
        # Adjoint Input
        dy = torch.randn(*x_local_shape, device=device)

    # y = F @ x
    y = layer(x)

    # dx = F* @ dy
    y.backward(dy)
    dx = x.grad

    x = x.detach()
    dx = dx.detach()
    dy = dy.detach()
    y = y.detach()

    check_adjoint_test_tight(P_world, x, dx, y, dy)

    P_world.deactivate()
    P_x_base.deactivate()
    P_x.deactivate()
    P_y_base.deactivate()
    P_y.deactivate()
Example #18
def test_halo_exchange_adjoint(barrier_fence_fixture, comm_split_fixture,
                               P_x_ranks, P_x_shape, x_global_shape, dtype,
                               kernel_size, stride, padding, dilation,
                               MockKernelStyle):
    import numpy as np
    import torch
    import torch.nn.functional as F

    from distdl.backends.mpi.partition import MPIPartition
    from distdl.nn.halo_exchange import HaloExchange
    from distdl.utilities.slicing import compute_subshape
    from distdl.utilities.torch import distdl_padding_to_torch_padding
    from distdl.utilities.torch import zero_volume_tensor

    device = torch.device('cuda' if use_cuda else 'cpu')

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return
    P_world = MPIPartition(base_comm)
    P_x_base = P_world.create_partition_inclusive(P_x_ranks)
    P_x = P_x_base.create_cartesian_topology_partition(P_x_shape)

    x_global_shape = np.asarray(x_global_shape)
    kernel_size = np.asarray(kernel_size)
    stride = np.asarray(stride)
    padding = np.asarray(padding)
    dilation = np.asarray(dilation)

    halo_shape = None
    recv_buffer_shape = None
    send_buffer_shape = None
    if P_x.active:
        mockup_layer = MockKernelStyle()
        exchange_info = mockup_layer._compute_exchange_info(
            x_global_shape, kernel_size, stride, padding, dilation, P_x.active,
            P_x.shape, P_x.index)
        halo_shape = exchange_info[0]
        recv_buffer_shape = exchange_info[1]
        send_buffer_shape = exchange_info[2]

    halo_layer = HaloExchange(P_x, halo_shape, recv_buffer_shape,
                              send_buffer_shape)
    halo_layer = halo_layer.to(device)

    x = zero_volume_tensor(x_global_shape[0], device=device)
    dy = zero_volume_tensor(x_global_shape[0], device=device)
    if P_x.active:
        x_local_shape = compute_subshape(P_x.shape, P_x.index, x_global_shape)

        padding = distdl_padding_to_torch_padding(halo_shape)

        x = torch.randn(*x_local_shape, device=device).to(dtype)

        # Pad the input with the halo space.  We are only testing the behavior of
        # the halo exchange so the input must be padded before we can do anything.
        x = F.pad(x, pad=padding, mode="constant", value=0)

        # dy has the same padded shape, but we want it to start with data
        # everywhere, halo included.
        dy = torch.randn(*x.shape, device=device).to(dtype)

    x.requires_grad = True

    # Halo Exchange (both fwd and adj) is in-place.  So, we copy the input
    # data and save the originals for the adjoint test.  Because it is
    # in-place, the clones themselves are modified.  This also prevents
    # issues with in-place operations on leaf nodes.

    x_clone = x.clone()
    dy_clone = dy.clone()

    # x_clone is modified in place by halo_layer, but we assign y to
    # reference it for clarity.  y and x_clone are the same object.
    y = halo_layer(x_clone)

    # dy_clone is modified in place by halo_layer-adjoint, but we assign dx to
    # reference it for clarity.  dx and dy_clone are the same object.
    # dx is not in the grad field as you might expect because the operation is
    # in-place.
    y.backward(dy_clone)
    dx = dy_clone

    x = x.detach()
    dx = dx.detach()
    dy = dy.detach()
    y = y.detach()

    check_adjoint_test_tight(P_world, x, dx, y, dy)

    P_world.deactivate()
    P_x_base.deactivate()
    P_x.deactivate()
Example #19
def test_transpose_adjoint(barrier_fence_fixture, comm_split_fixture,
                           P_x_ranks, P_x_shape, P_y_ranks, P_y_shape,
                           x_global_shape, balanced):

    import numpy as np
    import torch

    from distdl.backends.mpi.partition import MPIPartition
    from distdl.nn.transpose import DistributedTranspose
    from distdl.utilities.slicing import compute_subshape
    from distdl.utilities.torch import zero_volume_tensor

    device = torch.device('cuda' if use_cuda else 'cpu')

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return
    P_world = MPIPartition(base_comm)

    # Create the partitions
    P_x_base = P_world.create_partition_inclusive(P_x_ranks)
    P_x = P_x_base.create_cartesian_topology_partition(P_x_shape)

    P_y_base = P_world.create_partition_inclusive(P_y_ranks)
    P_y = P_y_base.create_cartesian_topology_partition(P_y_shape)

    # The global tensor size is the same for x and y
    layer = DistributedTranspose(P_x, P_y, preserve_batch=False)
    layer = layer.to(device)

    # Forward Input
    x = zero_volume_tensor(device=device)
    if P_x.active:
        if balanced:
            x_local_shape = compute_subshape(P_x.shape, P_x.index,
                                             x_global_shape)
        else:
            quotient = np.atleast_1d(x_global_shape) // np.atleast_1d(
                P_x_shape)
            remainder = np.atleast_1d(x_global_shape) % np.atleast_1d(
                P_x_shape)
            loc = np.where(P_x.index == 0)
            x_local_shape = quotient.copy()
            x_local_shape[loc] += remainder[loc]

        x = torch.randn(*x_local_shape, device=device)

    x.requires_grad = True

    # Adjoint Input
    dy = zero_volume_tensor(device=device)
    if P_y.active:
        y_local_shape = compute_subshape(P_y.shape, P_y.index, x_global_shape)
        dy = torch.randn(*y_local_shape, device=device)

    # y = F @ x
    y = layer(x)

    # dx = F* @ dy
    y.backward(dy)
    dx = x.grad

    x = x.detach()
    dx = dx.detach()
    dy = dy.detach()
    y = y.detach()

    check_adjoint_test_tight(P_world, x, dx, y, dy)

    P_world.deactivate()
    P_x_base.deactivate()
    P_x.deactivate()
    P_y_base.deactivate()
    P_y.deactivate()
Example #20
def test_all_sum_reduce_adjoint(barrier_fence_fixture,
                                comm_split_fixture,
                                P_x_ranks, P_x_shape,
                                x_global_shape,
                                axes_reduce):

    import numpy as np
    import torch

    from distdl.backends.mpi.partition import MPIPartition
    from distdl.nn.all_sum_reduce import AllSumReduce
    from distdl.utilities.torch import zero_volume_tensor

    device = torch.device('cuda' if use_cuda else 'cpu')

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return
    P_world = MPIPartition(base_comm)

    # Create the partitions
    P_x_base = P_world.create_partition_inclusive(P_x_ranks)
    P_x = P_x_base.create_cartesian_topology_partition(P_x_shape)

    # TODO #93: Change this to create a subtensor so we test when local tensors
    # have different shape.  Then, the output size will also be different, which
    # we will have to get from `y` itself.
    x_local_shape = np.asarray(x_global_shape)

    layer = AllSumReduce(P_x, axes_reduce)
    layer = layer.to(device)

    x = zero_volume_tensor(device=device)
    if P_x.active:
        x = 10*torch.ones(*x_local_shape, device=device)
    x.requires_grad = True

    dy = zero_volume_tensor(device=device)
    if P_x.active:
        # Adjoint Input
        dy = 0.1*torch.ones(*x_local_shape, device=device)

    # y = F @ x
    y = layer(x)

    # dx = F* @ dy
    y.backward(dy)
    dx = x.grad

    x = x.detach()
    dx = dx.detach()
    dy = dy.detach()
    y = y.detach()

    reduced_entry_value = 1
    for k in range(len(P_x_shape)):
        if k in axes_reduce:
            reduced_entry_value *= P_x_shape[k]

    assert torch.all(y == 10 * reduced_entry_value)
    assert torch.all(dx == 0.1 * reduced_entry_value)

    check_adjoint_test_tight(P_world, x, dx, y, dy)

    P_world.deactivate()
    P_x_base.deactivate()
    P_x.deactivate()
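
For a concrete instance of the assertions above (partition values illustrative): on a partition of shape [2, 2] with axes_reduce = (0,), every output entry sums contributions from the 2 workers along axis 0, so y == 10 * 2 everywhere and, by the adjoint relation, dx == 0.1 * 2:

P_x_shape, axes_reduce = [2, 2], (0,)
reduced_entry_value = 1
for k in range(len(P_x_shape)):
    if k in axes_reduce:
        reduced_entry_value *= P_x_shape[k]
assert reduced_entry_value == 2  # y would equal 20.0, dx would equal 0.2
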
Example #21
def test_batch_norm_no_training(barrier_fence_fixture, P_x_ranks, P_x_shape,
                                input_shape, num_features, eps, momentum,
                                affine, track_running_stats,
                                comm_split_fixture):

    import torch

    import distdl
    from distdl.backends.mpi.partition import MPIPartition
    from distdl.utilities.torch import zero_volume_tensor

    device = torch.device('cuda' if use_cuda else 'cpu')

    torch.manual_seed(0)

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return
    P_world = MPIPartition(base_comm)

    # Create the partitions
    num_dimensions = len(input_shape)
    P_in_out_base = P_world.create_partition_inclusive([0])
    P_in_out = P_in_out_base.create_cartesian_topology_partition(
        [1] * num_dimensions)
    P_x_base = P_world.create_partition_inclusive(P_x_ranks)
    P_x = P_x_base.create_cartesian_topology_partition(P_x_shape)

    # Create the input
    if P_world.rank == 0:
        input_eval = torch.rand(input_shape,
                                dtype=torch.float32,
                                device=device)
    else:
        input_eval = zero_volume_tensor(device=device)

    # Create the sequential network
    if P_world.rank == 0:
        seq_net = torch.nn.Sequential(
            torch.nn.BatchNorm1d(num_features=num_features,
                                 eps=eps,
                                 momentum=momentum,
                                 affine=affine,
                                 track_running_stats=track_running_stats))
        seq_net = seq_net.to(device)
    else:
        seq_net = None

    # Evaluate sequential network
    if P_world.rank == 0:
        seq_net.eval()
        seq_out = seq_net(input_eval)
        seq_out = seq_out.detach().cpu()

    # Create distributed network
    dist_net = torch.nn.Sequential(
        distdl.nn.Repartition(P_in_out, P_x),
        distdl.nn.DistributedBatchNorm(
            P_x,
            num_features=num_features,
            eps=eps,
            momentum=momentum,
            affine=affine,
            track_running_stats=track_running_stats),
        distdl.nn.Repartition(P_x, P_in_out))
    dist_net = dist_net.to(device)

    # Evaluate distributed network
    dist_net.eval()
    dist_out = dist_net(input_eval)
    dist_out = dist_out.detach().cpu()

    # Compare the distributed and sequential networks
    if P_world.rank == 0:
        assert dist_out.shape == seq_out.shape

        # Set the absolute tolerance to ~sqrt(eps_mach), or the torch default.
        # PyTorch took its defaults from NumPy, but NumPy defaults to 64-bit
        # floats, not the 32-bit floats torch uses.  Consequently, the default
        # torch atol is tighter than one can expect from two fp-equal 32-bit
        # numbers.  The NumPy default of 1e-8 is close to sqrt(eps_mach) for
        # 64-bit numbers, so for 32-bit numbers we set the tolerance to 1e-5,
        # a little tighter than sqrt(eps_mach) ~ 3e-4.
        if seq_out.dtype == torch.float64:
            atol = 1e-8
        elif seq_out.dtype == torch.float32:
            atol = 1e-5
        else:
            # torch default
            atol = 1e-8
        assert torch.allclose(dist_out, seq_out, atol=atol)

    P_world.deactivate()
    P_x_base.deactivate()
    P_x.deactivate()
    P_in_out_base.deactivate()
    P_in_out.deactivate()
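The tolerance reasoning in the comment above can be checked directly: the square root of the 32-bit machine epsilon is about 3.4e-4, so 1e-5 is a somewhat tighter choice, while the default atol of 1e-8 sits near the square root of the 64-bit epsilon. A quick sketch:

import numpy as np

eps32 = np.finfo(np.float32).eps  # ~1.19e-7
eps64 = np.finfo(np.float64).eps  # ~2.22e-16

print(np.sqrt(eps32))  # ~3.45e-4: a plausible fp32 absolute tolerance
print(np.sqrt(eps64))  # ~1.49e-8: close to the default atol of 1e-8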
Example No. 22
def test_repartition_dtype(barrier_fence_fixture, comm_split_fixture, dtype,
                           test_backward, P_x_ranks, P_x_shape, P_y_ranks,
                           P_y_shape, x_global_shape):

    import torch

    from distdl.backends.mpi.partition import MPIPartition
    from distdl.nn.repartition import Repartition
    from distdl.utilities.slicing import compute_subshape
    from distdl.utilities.torch import zero_volume_tensor

    device = torch.device('cuda' if use_cuda else 'cpu')

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return
    P_world = MPIPartition(base_comm)

    # Create the partitions
    P_x_base = P_world.create_partition_inclusive(P_x_ranks)
    P_x = P_x_base.create_cartesian_topology_partition(P_x_shape)

    P_y_base = P_world.create_partition_inclusive(P_y_ranks)
    P_y = P_y_base.create_cartesian_topology_partition(P_y_shape)

    # The global tensor size is the same for x and y
    layer = Repartition(P_x, P_y, preserve_batch=False)
    layer = layer.to(device)

    # Forward Input
    x = zero_volume_tensor(dtype=dtype, device=device)
    if P_x.active:
        x_local_shape = compute_subshape(P_x.shape, P_x.index, x_global_shape)
        x = 10 * torch.randn(*x_local_shape, device=device).to(dtype)

    x.requires_grad = test_backward
    # y = F @ x
    y = layer(x)
    if P_y.active:
        assert y.dtype == dtype

    if test_backward:
        # Adjoint Input
        dy = zero_volume_tensor(dtype=dtype, device=device)
        if P_y.active:
            y_local_shape = compute_subshape(P_y.shape, P_y.index,
                                             x_global_shape)
            dy = 10 * torch.randn(*y_local_shape, device=device).to(dtype)

        # dx = F* @ dy
        y.backward(dy)
        dx = x.grad
        if P_x.active:
            assert dx.dtype == dtype

    P_world.deactivate()
    P_x_base.deactivate()
    P_x.deactivate()
    P_y_base.deactivate()
    P_y.deactivate()
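The dtype assertions above check that repartitioning preserves the tensor's dtype through both the forward and the backward pass. A single-process analogue using a plain scalar multiply in place of Repartition (illustrative only):

import torch

x = torch.randn(4, dtype=torch.float16, requires_grad=True)
y = 2 * x                       # forward preserves the input dtype
assert y.dtype == torch.float16

y.backward(torch.ones_like(y))
assert x.grad.dtype == torch.float16  # and so does backward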
Example No. 23
def test_repartition_identity(barrier_fence_fixture, comm_split_fixture,
                              P_x_ranks, P_x_shape, x_global_shape, balanced):

    import numpy as np
    import torch

    from distdl.backends.mpi.partition import MPIPartition
    from distdl.nn.repartition import Repartition
    from distdl.utilities.slicing import compute_subshape
    from distdl.utilities.torch import zero_volume_tensor

    device = torch.device('cuda' if use_cuda else 'cpu')

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return
    P_world = MPIPartition(base_comm)

    # Create the partitions
    P_x_base = P_world.create_partition_inclusive(P_x_ranks)
    P_x = P_x_base.create_cartesian_topology_partition(P_x_shape)

    P_y = P_x

    # The global tensor size is the same for x and y
    layer = Repartition(P_x, P_y, preserve_batch=False)
    layer = layer.to(device)

    # Forward Input
    x = zero_volume_tensor(device=device)
    if P_x.active:
        if balanced:
            x_local_shape = compute_subshape(P_x.shape, P_x.index,
                                             x_global_shape)
        else:
            quotient = np.atleast_1d(x_global_shape) // np.atleast_1d(
                P_x_shape)
            remainder = np.atleast_1d(x_global_shape) % np.atleast_1d(
                P_x_shape)
            loc = np.where(P_x.index == 0)
            x_local_shape = quotient.copy()
            x_local_shape[loc] += remainder[loc]

        x = torch.randn(*x_local_shape, device=device)

    x.requires_grad = True

    # Adjoint Input
    dy = zero_volume_tensor(device=device)
    if P_y.active:
        y_local_shape = compute_subshape(P_y.shape, P_y.index, x_global_shape)
        dy = torch.randn(*y_local_shape, device=device)

    # y = F @ x
    y = layer(x)

    # In the balanced case, this should be a true identity, so there should
    # be no communication performed, just self-copies.
    if balanced:
        for sl, sz, p in layer.P_x_to_y_overlaps:
            assert p == "self" or (sl, sz, p) == (None, None, None)
        for sl, sz, p in layer.P_y_to_x_overlaps:
            assert p == "self" or (sl, sz, p) == (None, None, None)

    # dx = F* @ dy
    y.backward(dy)
    dx = x.grad

    x = x.detach()
    dx = dx.detach()
    dy = dy.detach()
    y = y.detach()

    check_adjoint_test_tight(P_world, x, dx, y, dy)

    P_world.deactivate()
    P_x_base.deactivate()
    P_x.deactivate()
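The unbalanced branch above hands every worker the integer quotient of the global shape by the partition shape, with the remainder absorbed by the workers at index 0 along each dimension. A standalone sketch of that shape arithmetic, with illustrative values:

import numpy as np

x_global_shape = np.array([7, 5])  # illustrative global tensor shape
P_x_shape = np.array([2, 2])       # illustrative partition shape
index = np.array([0, 1])           # this worker's illustrative coordinates

quotient = x_global_shape // P_x_shape   # [3, 2]
remainder = x_global_shape % P_x_shape   # [1, 1]

# Workers at coordinate 0 along a dimension absorb that dimension's remainder.
loc = np.where(index == 0)
x_local_shape = quotient.copy()
x_local_shape[loc] += remainder[loc]

print(x_local_shape)  # [4, 2]: the extra row goes to index-0 workers in dim 0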
Example No. 24
def test_buffer_management_transpose_network(barrier_fence_fixture,
                                             comm_split_fixture):

    import numpy as np
    import torch

    import distdl
    from distdl.backends.mpi.buffer import MPIBufferManager
    from distdl.backends.mpi.partition import MPIPartition
    from distdl.utilities.torch import zero_volume_tensor

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return

    buffer_manager = MPIBufferManager()
    P_world = MPIPartition(base_comm)

    # Create the partitions

    P_1_base = P_world.create_partition_inclusive([0])
    P_1 = P_1_base.create_cartesian_topology_partition([1, 1])

    P_3_base = P_world.create_partition_inclusive([1, 2, 3])
    P_3 = P_3_base.create_cartesian_topology_partition([1, 3])

    P_4_base = P_world.create_partition_inclusive([0, 1, 2, 3])
    P_4 = P_4_base.create_cartesian_topology_partition([1, 4])

    tr1 = distdl.nn.DistributedTranspose(P_1,
                                         P_4,
                                         buffer_manager=buffer_manager)
    tr2 = distdl.nn.DistributedTranspose(P_4,
                                         P_4,
                                         buffer_manager=buffer_manager)
    tr3 = distdl.nn.DistributedTranspose(P_4,
                                         P_3,
                                         buffer_manager=buffer_manager)
    tr4 = distdl.nn.DistributedTranspose(P_3,
                                         P_1,
                                         buffer_manager=buffer_manager)

    x = zero_volume_tensor(1)
    if P_1.active:
        x = torch.randn(1, 10)
    x.requires_grad = True

    # [0   1   2   3   4   5   6   7   8   9] to
    # [0   1   2] [3   4   5] [6   7] [8   9]
    x2 = tr1(x)
    n_buffers_by_rank = (3, 1, 1, 1)
    assert len(buffer_manager.buffers_map[np.float32]) == n_buffers_by_rank[
        P_world.rank]

    # [0   1   2] [3   4   5] [6   7] [8   9] to
    # [0   1   2] [3   4   5] [6   7] [8   9]
    x3 = tr2(x2)
    n_buffers_by_rank = (3, 1, 1, 1)
    assert len(buffer_manager.buffers_map[np.float32]) == n_buffers_by_rank[
        P_world.rank]

    #    [0   1   2] [3   4   5] [6   7] [8   9] to
    # [] [0   1   2   3] [4   5   6] [7   8   9]
    x4 = tr3(x3)
    n_buffers_by_rank = (3, 2, 2, 1)
    assert len(buffer_manager.buffers_map[np.float32]) == n_buffers_by_rank[
        P_world.rank]

    # [] [0   1   2   3] [4   5   6] [7   8   9] to
    #    [0   1   2   3   4   5   6   7   8   9]
    y = tr4(x4)
    n_buffers_by_rank = (3, 2, 2, 1)
    assert len(buffer_manager.buffers_map[np.float32]) == n_buffers_by_rank[
        P_world.rank]

    dy = zero_volume_tensor(1)
    if P_1.active:
        dy = torch.randn(1, 10)
    dy.requires_grad = True

    y.backward(dy)
    dx = x.grad

    # Through the backward call, the buffer counts do not change
    n_buffers_by_rank = (3, 2, 2, 1)
    assert len(buffer_manager.buffers_map[np.float32]) == n_buffers_by_rank[
        P_world.rank]

    # And adjointness is still preserved

    x = x.detach()
    dx = dx.detach()
    dy = dy.detach()
    y = y.detach()

    check_adjoint_test_tight(P_world, x, dx, y, dy)

    P_world.deactivate()
    P_1_base.deactivate()
    P_1.deactivate()
    P_3_base.deactivate()
    P_3.deactivate()
    P_4_base.deactivate()
    P_4.deactivate()
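The buffer counts asserted above are consistent with a pool shared across layers and keyed by dtype: each transpose reuses existing buffers and the pool grows only when a layer needs more simultaneous buffers than any earlier one, which is why the counts never shrink across the four transposes or the backward pass. A toy stand-in for that reuse policy (not the actual distdl MPIBufferManager API):

import numpy as np

class ToyBufferManager:
    """Reuse buffers by dtype; grow the pool only on demand."""

    def __init__(self):
        self.buffers_map = {}  # dtype -> list of allocated buffers

    def request_buffers(self, n, dtype=np.float32):
        pool = self.buffers_map.setdefault(dtype, [])
        # Allocate only the shortfall; earlier buffers are reused.
        for _ in range(max(0, n - len(pool))):
            pool.append(np.empty(0, dtype=dtype))
        return pool[:n]

mgr = ToyBufferManager()
mgr.request_buffers(3)                        # a layer needing 3 buffers
assert len(mgr.buffers_map[np.float32]) == 3
mgr.request_buffers(2)                        # a later layer needing fewer
assert len(mgr.buffers_map[np.float32]) == 3  # the pool never shrinks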
Example No. 25
def test_conv_versus_pytorch(barrier_fence_fixture, comm_split_fixture,
                             P_x_ranks, P_x_shape, input_dimensions,
                             x_global_shape, kernel_size, padding, stride,
                             dilation, bias):

    import numpy as np
    import torch
    from torch.nn import Conv1d
    from torch.nn import Conv2d
    from torch.nn import Conv3d

    from distdl.backends.mpi.partition import MPIPartition
    from distdl.nn.conv_feature import DistributedFeatureConv1d
    from distdl.nn.conv_feature import DistributedFeatureConv2d
    from distdl.nn.conv_feature import DistributedFeatureConv3d
    from distdl.nn.repartition import Repartition
    from distdl.utilities.torch import zero_volume_tensor

    device = torch.device('cuda' if use_cuda else 'cpu')

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return
    P_world = MPIPartition(base_comm)
    P_world._comm.Barrier()

    # Create the partitions
    P_0_base = P_world.create_partition_inclusive(np.arange(1))
    P_x_base = P_world.create_partition_inclusive(P_x_ranks)
    P_0 = P_0_base.create_cartesian_topology_partition([1] * len(P_x_shape))
    P_x = P_x_base.create_cartesian_topology_partition(P_x_shape)

    scatter_layer_x = Repartition(P_0, P_x)
    scatter_layer_x = scatter_layer_x.to(device)
    scatter_layer_y = Repartition(P_0, P_x)
    scatter_layer_y = scatter_layer_y.to(device)
    gather_layer_x = Repartition(P_x, P_0)
    gather_layer_x = gather_layer_x.to(device)
    gather_layer_y = Repartition(P_x, P_0)
    gather_layer_y = gather_layer_y.to(device)

    # Create the layers
    if input_dimensions == 1:
        dist_layer_type = DistributedFeatureConv1d
        seq_layer_type = Conv1d
    elif input_dimensions == 2:
        dist_layer_type = DistributedFeatureConv2d
        seq_layer_type = Conv2d
    elif input_dimensions == 3:
        dist_layer_type = DistributedFeatureConv3d
        seq_layer_type = Conv3d
    dist_layer = dist_layer_type(P_x,
                                 in_channels=x_global_shape[1],
                                 out_channels=10,
                                 kernel_size=kernel_size,
                                 padding=padding,
                                 stride=stride,
                                 dilation=dilation,
                                 bias=bias)
    dist_layer = dist_layer.to(device)
    if P_0.active:
        seq_layer = seq_layer_type(in_channels=x_global_shape[1],
                                   out_channels=10,
                                   kernel_size=kernel_size,
                                   padding=padding,
                                   stride=stride,
                                   dilation=dilation,
                                   bias=bias)
        # set the weights of both layers to be the same
        seq_layer = seq_layer.to(device)
        weight = torch.rand_like(seq_layer.weight, device=device)
        seq_layer.weight.data = weight
        dist_layer.weight.data = weight
        if bias:
            bias_weight = torch.rand_like(seq_layer.bias, device=device)
            seq_layer.bias.data = bias_weight
            dist_layer.bias.data = bias_weight

    # Forward Input
    x_ref = zero_volume_tensor(device=device)
    x_ref.requires_grad = True
    dy_ref = zero_volume_tensor(device=device)

    # Construct the inputs to the forward and backward functions, as well as
    # the outputs of the sequential layer
    if P_0.active:
        x_ref = torch.randn(*x_global_shape, device=device)
        x_ref.requires_grad = True
        y_ref = seq_layer(x_ref)
        y_global_shape_calc = y_ref.shape

        dy_ref = torch.randn(*y_global_shape_calc, device=device)

        y_ref.backward(dy_ref)
        dx_ref = x_ref.grad

    # Ensure that the scatter is not part of the computation we are testing
    with torch.no_grad():
        x = scatter_layer_x(x_ref.detach())
        dy = scatter_layer_y(dy_ref.detach())

    x.requires_grad = True

    y = dist_layer(x)
    y.backward(dy)
    dx = x.grad

    # Ensure that the gather is not part of the computation we are testing
    with torch.no_grad():
        dx_comp = gather_layer_x(dx.detach())
        y_comp = gather_layer_y(y.detach())

    if P_0.active:

        # Set the absolute tolerance to ~sqrt(eps_mach), or the torch default.
        # PyTorch took its defaults from NumPy, but NumPy defaults to 64-bit
        # floats, not the 32-bit floats torch uses.  Consequently, the default
        # torch atol is tighter than one can expect from two fp-equal 32-bit
        # numbers.  The NumPy default of 1e-8 is close to sqrt(eps_mach) for
        # 64-bit numbers, so for 32-bit numbers we set the tolerance to 1e-5,
        # a little tighter than sqrt(eps_mach) ~ 3e-4.
        if x_ref.dtype == torch.float64:
            atol = 1e-8
        elif x_ref.dtype == torch.float32:
            atol = 1e-5
        else:
            # torch default
            atol = 1e-8

        # Test the result of each entry independently
        assert torch.allclose(y_ref, y_comp, atol=atol)
        assert torch.allclose(dx_ref, dx_comp, atol=atol)

    P_world.deactivate()
    P_0_base.deactivate()
    P_0.deactivate()
    P_x_base.deactivate()
    P_x.deactivate()
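The weight-copying step in the test above is what makes the sequential and distributed layers comparable: both start from identical parameters, so any disagreement beyond roundoff is a real defect. A single-process sketch of the same trick between two ordinary layers (illustrative only):

import torch
from torch.nn import Conv2d

torch.manual_seed(0)

a = Conv2d(3, 10, kernel_size=3, bias=True)
b = Conv2d(3, 10, kernel_size=3, bias=True)

# Copy a's parameters into b, mirroring how the test aligns the
# sequential and distributed layers before comparing outputs.
b.weight.data = a.weight.data.clone()
b.bias.data = a.bias.data.clone()

x = torch.randn(1, 3, 8, 8)
assert torch.allclose(a(x), b(x))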
Example No. 26
def test_linear_adjoint_bias(barrier_fence_fixture,
                             comm_split_fixture,
                             P_x_ranks, P_x_shape,
                             P_y_ranks, P_y_shape,
                             P_w_ranks, P_w_shape,
                             x_global_shape,
                             y_global_shape):

    import numpy as np
    import torch

    from distdl.backends.mpi.partition import MPIPartition
    from distdl.nn.linear import DistributedLinear
    from distdl.utilities.slicing import compute_subshape
    from distdl.utilities.torch import zero_volume_tensor

    device = torch.device('cuda' if use_cuda else 'cpu')

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return
    P_world = MPIPartition(base_comm)

    # Create the partitions
    P_x_base = P_world.create_partition_inclusive(P_x_ranks)
    P_x = P_x_base.create_cartesian_topology_partition(P_x_shape)

    P_y_base = P_world.create_partition_inclusive(P_y_ranks)
    P_y = P_y_base.create_cartesian_topology_partition(P_y_shape)

    P_w_base = P_world.create_partition_inclusive(P_w_ranks)
    P_w = P_w_base.create_cartesian_topology_partition(P_w_shape)

    x_global_shape = np.asarray(x_global_shape)
    y_global_shape = np.asarray(y_global_shape)

    layer = DistributedLinear(P_x, P_y, P_w,
                              x_global_shape[1],
                              y_global_shape[1],
                              bias=True)
    layer = layer.to(device)

    x = zero_volume_tensor(x_global_shape[0], device=device)
    if P_x.active:
        x_local_shape = compute_subshape(P_x.shape,
                                         P_x.index,
                                         x_global_shape)
        # For this test, we are only testing to see if the adjoint works
        # correctly for the bias term.  But the adjoint test only works on the
        # Jacobian of the linear layer.  The Jacobian block for b is 0 for x and
        # W, so killing x makes the forward operator equal to its Jacobian and
        # we can test to see that adjoint is computed correctly.
        x = torch.zeros(*x_local_shape, device=device)
    x.requires_grad = True

    y = layer(x)

    dy = zero_volume_tensor(x_global_shape[0], device=device)
    if P_y.active:
        dy = torch.randn(*y.shape, device=device)

    y.backward(dy)

    b = zero_volume_tensor(device=device)
    db = zero_volume_tensor(device=device)
    if P_w.active and P_w.index[-1] == 0:
        b = layer.sublinear.bias.detach()
        db = layer.sublinear.bias.grad.detach()

    dy = dy.detach()
    y = y.detach()

    check_adjoint_test_tight(P_world, b, db, y, dy)

    P_world.deactivate()
    P_x_base.deactivate()
    P_x.deactivate()
    P_y_base.deactivate()
    P_y.deactivate()
    P_w_base.deactivate()
    P_w.deactivate()
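The x = 0 trick explained in the comment above also works for a plain torch.nn.Linear: with the input zeroed, the map from the bias to the output is the identity, so the adjoint identity can be checked directly between (b, db) and (y, dy). A single-process sketch (illustrative only, not part of this test suite):

import torch

torch.manual_seed(0)

layer = torch.nn.Linear(4, 3, bias=True)

# Zero input: y = W @ 0 + b = b, so the forward map restricted to the
# bias is the identity, and its adjoint carries dy straight to b.grad.
x = torch.zeros(4)
y = layer(x)

dy = torch.randn(3)
y.backward(dy)

b = layer.bias.detach()
db = layer.bias.grad.detach()

# <y, dy> == <b, db> for the identity map and its adjoint.
assert torch.isclose(torch.dot(y.detach(), dy), torch.dot(b, db))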