Example #1
def test_excepts_mismatched_nondivisible_tensor(barrier_fence_fixture,
                                                comm_split_fixture):

    import numpy as np
    import pytest
    import torch

    from distdl.backends.mpi.partition import MPIPartition
    from distdl.nn.transpose import DistributedTranspose
    from distdl.utilities.slicing import compute_subshape
    from distdl.utilities.torch import zero_volume_tensor

    device = torch.device('cuda' if use_cuda else 'cpu')

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return
    P_world = MPIPartition(base_comm)

    # A tensor with size 1 in a dimension cannot be partitioned in that
    # dimension.  (See last dimension of output and tensor.)
    in_shape = (1, 4, 1, 1)
    out_shape = (1, 1, 1, 2)
    x_global_shape = np.array([1, 16, 5, 1])

    in_size = np.prod(in_shape)
    out_size = np.prod(out_shape)

    # Create the partitions
    P_x_base = P_world.create_partition_inclusive(np.arange(0, in_size))
    P_x = P_x_base.create_cartesian_topology_partition(in_shape)

    P_y_base = P_world.create_partition_inclusive(
        np.arange(P_world.size - out_size, P_world.size))
    P_y = P_y_base.create_cartesian_topology_partition(out_shape)

    with pytest.raises(ValueError) as e_info:  # noqa: F841
        layer = DistributedTranspose(P_x, P_y)
        layer = layer.to(device)

        # Forward Input
        x = zero_volume_tensor(device=device)
        if P_x.active:
            x_local_shape = compute_subshape(P_x.shape, P_x.index,
                                             x_global_shape)
            x = torch.randn(*x_local_shape, device=device)
        x.requires_grad = True

        layer(x)

    P_world.deactivate()
    P_x_base.deactivate()
    P_x.deactivate()
    P_y_base.deactivate()
    P_y.deactivate()
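
# A hypothetical check mirroring the comment in the test above (an illustration
# only, not the actual validation DistributedTranspose performs): the last
# dimension of the output partition has 2 workers, but the corresponding global
# tensor dimension has extent 1, so that dimension cannot be split.
import numpy as np

x_global_shape = np.array([1, 16, 5, 1])
out_shape = np.array([1, 1, 1, 2])
print(x_global_shape >= out_shape)  # [ True  True  True False] -- the last dimension fails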
Example #2
def test_excepts_mismatched_partitions(barrier_fence_fixture,
                                       comm_split_fixture):

    import numpy as np
    import pytest

    from distdl.backends.mpi.partition import MPIPartition
    from distdl.nn.transpose import DistributedTranspose

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return
    P_world = MPIPartition(base_comm)

    in_shape = (1, 4, 1, 1)
    out_shape = (1, 2)

    in_size = np.prod(in_shape)
    out_size = np.prod(out_shape)

    # Create the partitions
    P_x_base = P_world.create_partition_inclusive(np.arange(0, in_size))
    P_x = P_x_base.create_cartesian_topology_partition(in_shape)

    P_y_base = P_world.create_partition_inclusive(np.arange(P_world.size-out_size, P_world.size))
    P_y = P_y_base.create_cartesian_topology_partition(out_shape)

    with pytest.raises(ValueError) as e_info:  # noqa: F841
        DistributedTranspose(P_x, P_y)
Example #3
def test_transpose_adjoint(barrier_fence_fixture,
                           comm_split_fixture,
                           P_x_ranks, P_x_shape,
                           P_y_ranks, P_y_shape,
                           x_global_shape):

    import numpy as np
    import torch

    from distdl.backends.mpi.partition import MPIPartition
    from distdl.nn.transpose import DistributedTranspose
    from distdl.utilities.slicing import compute_subshape
    from distdl.utilities.torch import zero_volume_tensor

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return
    P_world = MPIPartition(base_comm)

    # Create the partitions
    P_x_base = P_world.create_partition_inclusive(P_x_ranks)
    P_x = P_x_base.create_cartesian_topology_partition(P_x_shape)

    P_y_base = P_world.create_partition_inclusive(P_y_ranks)
    P_y = P_y_base.create_cartesian_topology_partition(P_y_shape)

    # The global tensor size is the same for x and y
    layer = DistributedTranspose(P_x, P_y, preserve_batch=False)

    # Forward Input
    x = zero_volume_tensor()
    if P_x.active:
        x_local_shape = compute_subshape(P_x.shape,
                                         P_x.index,
                                         x_global_shape)
        x = torch.Tensor(np.random.randn(*x_local_shape))
    x.requires_grad = True

    # Adjoint Input
    dy = zero_volume_tensor()
    if P_y.active:
        y_local_shape = compute_subshape(P_y.shape,
                                         P_y.index,
                                         x_global_shape)
        dy = torch.Tensor(np.random.randn(*y_local_shape))

    # y = F @ x
    y = layer(x)

    # dx = F* @ dy
    y.backward(dy)
    dx = x.grad

    x = x.detach()
    dx = dx.detach()
    dy = dy.detach()
    y = y.detach()

    check_adjoint_test_tight(P_world, x, dx, y, dy)

    P_world.deactivate()
    P_x_base.deactivate()
    P_x.deactivate()
    P_y_base.deactivate()
    P_y.deactivate()
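
# A hedged sketch of the identity that check_adjoint_test_tight verifies: for a
# linear map F, <F x, dy> == <x, F* dy>, where F* is realized by autograd's
# backward pass.  An ordinary linear layer stands in for the distributed
# transpose here; this is an illustration, not the test harness itself.
import torch

F = torch.nn.Linear(4, 3, bias=False)
x = torch.randn(4, requires_grad=True)
dy = torch.randn(3)

y = F(x)            # y = F @ x
y.backward(dy)      # dx = F* @ dy
dx = x.grad

lhs = torch.dot(y.detach(), dy)  # <F x, dy>
rhs = torch.dot(x.detach(), dx)  # <x, F* dy>
assert torch.allclose(lhs, rhs, atol=1e-5)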
Example #4
def test_excepts_mismatched_output_partition_tensor(barrier_fence_fixture,
                                                    comm_split_fixture):

    import numpy as np
    import pytest
    import torch

    from distdl.backends.mpi.partition import MPIPartition
    from distdl.nn.transpose import DistributedTranspose
    from distdl.utilities.slicing import compute_subshape
    from distdl.utilities.torch import zero_volume_tensor

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return
    P_world = MPIPartition(base_comm)

    # Output partition rank must match tensor rank
    in_shape = (4, 1, 1)
    out_shape = (1, 1, 1, 2)
    x_global_shape = np.array([16, 5, 5])

    in_size = np.prod(in_shape)
    out_size = np.prod(out_shape)

    # Create the partitions
    P_x_base = P_world.create_partition_inclusive(np.arange(0, in_size))
    P_x = P_x_base.create_cartesian_topology_partition(in_shape)

    P_y_base = P_world.create_partition_inclusive(np.arange(P_world.size-out_size, P_world.size))
    P_y = P_y_base.create_cartesian_topology_partition(out_shape)

    with pytest.raises(ValueError) as e_info:  # noqa: F841
        layer = DistributedTranspose(P_x, P_y)

        # Forward Input
        x = zero_volume_tensor()
        if P_x.active:
            x_local_shape = compute_subshape(P_x.shape,
                                             P_x.index,
                                             x_global_shape)
            x = torch.Tensor(np.random.randn(*x_local_shape))
        x.requires_grad = True

        layer(x)
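
# A hypothetical illustration of the mismatch in the test above: the output
# partition is 4-dimensional while the global tensor is only 3-dimensional,
# so the partition cannot be mapped onto the tensor.
out_shape = (1, 1, 1, 2)
x_global_shape = (16, 5, 5)
assert len(out_shape) != len(x_global_shape)  # 4 != 3: rank mismatch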
Example #5
def test_conv_versus_pytorch(barrier_fence_fixture, comm_split_fixture,
                             P_x_ranks, P_x_shape, input_dimensions,
                             x_global_shape, kernel_size, padding, stride,
                             dilation, bias):

    import numpy as np
    import torch
    from torch.nn import Conv1d
    from torch.nn import Conv2d
    from torch.nn import Conv3d

    from distdl.backends.mpi.partition import MPIPartition
    from distdl.nn.conv_feature import DistributedFeatureConv1d
    from distdl.nn.conv_feature import DistributedFeatureConv2d
    from distdl.nn.conv_feature import DistributedFeatureConv3d
    from distdl.nn.transpose import DistributedTranspose
    from distdl.utilities.torch import zero_volume_tensor

    device = torch.device('cuda' if use_cuda else 'cpu')

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return
    P_world = MPIPartition(base_comm)
    P_world._comm.Barrier()

    # Create the partitions
    P_0_base = P_world.create_partition_inclusive(np.arange(1))
    P_x_base = P_world.create_partition_inclusive(P_x_ranks)
    P_0 = P_0_base.create_cartesian_topology_partition([1] * len(P_x_shape))
    P_x = P_x_base.create_cartesian_topology_partition(P_x_shape)

    scatter_layer_x = DistributedTranspose(P_0, P_x)
    scatter_layer_x = scatter_layer_x.to(device)
    scatter_layer_y = DistributedTranspose(P_0, P_x)
    scatter_layer_y = scatter_layer_y.to(device)
    gather_layer_x = DistributedTranspose(P_x, P_0)
    gather_layer_x = gather_layer_x.to(device)
    gather_layer_y = DistributedTranspose(P_x, P_0)
    gather_layer_y = gather_layer_y.to(device)

    # Create the layers
    if input_dimensions == 1:
        dist_layer_type = DistributedFeatureConv1d
        seq_layer_type = Conv1d
    elif input_dimensions == 2:
        dist_layer_type = DistributedFeatureConv2d
        seq_layer_type = Conv2d
    elif input_dimensions == 3:
        dist_layer_type = DistributedFeatureConv3d
        seq_layer_type = Conv3d
    dist_layer = dist_layer_type(P_x,
                                 in_channels=x_global_shape[1],
                                 out_channels=10,
                                 kernel_size=kernel_size,
                                 padding=padding,
                                 stride=stride,
                                 dilation=dilation,
                                 bias=bias)
    dist_layer = dist_layer.to(device)
    if P_0.active:
        seq_layer = seq_layer_type(in_channels=x_global_shape[1],
                                   out_channels=10,
                                   kernel_size=kernel_size,
                                   padding=padding,
                                   stride=stride,
                                   dilation=dilation,
                                   bias=bias)
        # set the weights of both layers to be the same
        seq_layer = seq_layer.to(device)
        weight = torch.rand_like(seq_layer.weight, device=device)
        seq_layer.weight.data = weight
        dist_layer.weight.data = weight
        if bias:
            bias_weight = torch.rand_like(seq_layer.bias, device=device)
            seq_layer.bias.data = bias_weight
            dist_layer.bias.data = bias_weight

    # Forward Input
    x_ref = zero_volume_tensor(device=device)
    x_ref.requires_grad = True
    dy_ref = zero_volume_tensor(device=device)

    # Construct the inputs to the forward and backward functions as well as
    # the outputs of the sequential layer
    if P_0.active:
        x_ref = torch.randn(*x_global_shape, device=device)
        x_ref.requires_grad = True
        y_ref = seq_layer(x_ref)
        y_global_shape_calc = y_ref.shape

        dy_ref = torch.randn(*y_global_shape_calc, device=device)

        y_ref.backward(dy_ref)
        dx_ref = x_ref.grad

    # Ensure that the scatter is not part of the computation we are testing
    with torch.no_grad():
        x = scatter_layer_x(x_ref.detach())
        dy = scatter_layer_y(dy_ref.detach())

    x.requires_grad = True

    y = dist_layer(x)
    y.backward(dy)
    dx = x.grad

    # Ensure that the gather is not part of the computation we are testing
    with torch.no_grad():
        dx_comp = gather_layer_x(dx.detach())
        y_comp = gather_layer_y(y.detach())

    if P_0.active:

        # Set the absolute tolerance to ~sqrt(e_mach), or the default.
        # PyTorch got its defaults from NumPy, but NumPy defaults to 64-bit
        # floats, not 32-bit floats as torch does.  Consequently, the default
        # torch atol is actually tighter than one can expect from two fp-equal
        # floating point numbers.  The NumPy default of 1e-8 is closer to
        # sqrt(e_mach) for 64-bit numbers.  So we set the 32-bit tolerance a
        # little tighter than sqrt(1e-7), to 1e-5.
        if x_ref.dtype == torch.float64:
            atol = 1e-8
        elif x_ref.dtype == torch.float32:
            atol = 1e-5
        else:
            # torch default
            atol = 1e-8

        # Test the result of each entry independently
        assert torch.allclose(y_ref, y_comp, atol=atol)
        assert torch.allclose(dx_ref, dx_comp, atol=atol)

    P_world.deactivate()
    P_0_base.deactivate()
    P_0.deactivate()
    P_x_base.deactivate()
    P_x.deactivate()
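
# The tolerance reasoning in the comment above can be checked directly; a small
# sketch, independent of the test, comparing sqrt(machine epsilon) for the two
# float widths:
import numpy as np

print(np.sqrt(np.finfo(np.float64).eps))  # ~1.5e-08, close to NumPy's default atol of 1e-8
print(np.sqrt(np.finfo(np.float32).eps))  # ~3.5e-04, so atol=1e-5 is a tighter 32-bit choice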
Example #6
def test_matches_sequential(barrier_fence_fixture, comm_split_fixture,
                            P_x_ranks, P_x_shape, input_dimensions,
                            x_global_shape, kernel_size, padding, stride,
                            dilation, layer_type):

    import numpy as np
    import torch

    from distdl.backends.mpi.partition import MPIPartition
    from distdl.nn.transpose import DistributedTranspose
    from distdl.utilities.torch import zero_volume_tensor

    device = torch.device('cuda' if use_cuda else 'cpu')

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return
    P_world = MPIPartition(base_comm)
    P_world._comm.Barrier()

    # Create the partitions
    P_0_base = P_world.create_partition_inclusive(np.arange(1))
    P_x_base = P_world.create_partition_inclusive(P_x_ranks)
    P_0 = P_0_base.create_cartesian_topology_partition([1] * len(P_x_shape))
    P_x = P_x_base.create_cartesian_topology_partition(P_x_shape)

    scatter_layer_x = DistributedTranspose(P_0, P_x).to(device)
    scatter_layer_y = DistributedTranspose(P_0, P_x).to(device)
    gather_layer_x = DistributedTranspose(P_x, P_0).to(device)
    gather_layer_y = DistributedTranspose(P_x, P_0).to(device)

    # Create the layers
    if input_dimensions == 1:
        if layer_type == 'max':
            from torch.nn import MaxPool1d as SequentialPoolType

            from distdl.nn import DistributedMaxPool1d as DistributedPoolType
        else:
            from torch.nn import AvgPool1d as SequentialPoolType

            from distdl.nn import DistributedAvgPool1d as DistributedPoolType
    elif input_dimensions == 2:
        if layer_type == 'max':
            from torch.nn import MaxPool2d as SequentialPoolType

            from distdl.nn import DistributedMaxPool2d as DistributedPoolType
        else:
            from torch.nn import AvgPool2d as SequentialPoolType

            from distdl.nn import DistributedAvgPool2d as DistributedPoolType
    elif input_dimensions == 3:
        if layer_type == 'max':
            from torch.nn import MaxPool3d as SequentialPoolType

            from distdl.nn import DistributedMaxPool3d as DistributedPoolType
        else:
            from torch.nn import AvgPool3d as SequentialPoolType

            from distdl.nn import DistributedAvgPool3d as DistributedPoolType

    # PyTorch AvgPool doesn't support dilation, so skip the test if the combination comes up
    dilation_is_default = dilation == 1 or all(x == 1 for x in dilation)
    if layer_type == 'avg' and not dilation_is_default:
        return

    layer_kwargs = {
        'kernel_size': kernel_size,
        'padding': padding,
        'stride': stride
    }

    # Only max pool layers support dilation
    if layer_type == 'max':
        layer_kwargs['dilation'] = dilation

    dist_layer = DistributedPoolType(P_x, **layer_kwargs).to(device)
    if P_0.active:
        seq_layer = SequentialPoolType(**layer_kwargs).to(device)

    # Forward Input
    x_ref = zero_volume_tensor(device=device)
    x_ref.requires_grad = True
    dy_ref = zero_volume_tensor(device=device)

    # Construct the inputs to the forward and backward functions as well as
    # the outputs of the sequential layer
    if P_0.active:
        x_ref = torch.randn(*x_global_shape, device=device)
        x_ref.requires_grad = True
        y_ref = seq_layer(x_ref)
        y_global_shape_calc = y_ref.shape

        dy_ref = torch.randn(*y_global_shape_calc, device=device)

        y_ref.backward(dy_ref)
        dx_ref = x_ref.grad

    # Ensure that the scatter is not part of the computation we are testing
    with torch.no_grad():
        x = scatter_layer_x(x_ref.detach())
        dy = scatter_layer_y(dy_ref.detach())

    x.requires_grad = True

    y = dist_layer(x)
    y.backward(dy)
    dx = x.grad

    # Ensure that the gather is not part of the computation we are testing
    with torch.no_grad():
        dx_comp = gather_layer_x(dx.detach())
        y_comp = gather_layer_y(y.detach())

    if P_0.active:

        # Set the absolute tolerance to ~sqrt(e_mach), or the default.
        # PyTorch got its defaults from NumPy, but NumPy defaults to 64-bit
        # floats, not 32-bit floats as torch does.  Consequently, the default
        # torch atol is actually tighter than one can expect from two fp-equal
        # floating point numbers.  The NumPy default of 1e-8 is closer to
        # sqrt(e_mach) for 64-bit numbers.  So we set the 32-bit tolerance a
        # little tighter than sqrt(1e-7), to 1e-5.
        if x_ref.dtype == torch.float64:
            atol = 1e-8
        elif x_ref.dtype == torch.float32:
            atol = 1e-5
        else:
            # torch default
            atol = 1e-8

        assert torch.allclose(y_ref, y_comp, atol=atol)
        assert torch.allclose(dx_ref, dx_comp, atol=atol)

    P_world.deactivate()
    P_0_base.deactivate()
    P_0.deactivate()
    P_x_base.deactivate()
    P_x.deactivate()
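
# A hedged check of the comment above that only max pooling supports dilation:
# PyTorch's AvgPool constructors take no dilation argument, while the MaxPool
# constructors do.
import inspect

from torch.nn import AvgPool1d
from torch.nn import MaxPool1d

assert 'dilation' not in inspect.signature(AvgPool1d).parameters
assert 'dilation' in inspect.signature(MaxPool1d).parameters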
Example #7
def test_upsample_matches_sequential(barrier_fence_fixture, comm_split_fixture,
                                     P_x_ranks, P_x_shape, x_global_shape,
                                     scale_factor, y_global_shape, mode,
                                     align_corners, use_size):

    import torch

    from distdl.backends.mpi.partition import MPIPartition
    from distdl.nn.transpose import DistributedTranspose
    from distdl.nn.upsampling import DistributedUpsample
    from distdl.utilities.torch import zero_volume_tensor

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return

    # Test align_corners only in the linear case, otherwise ignore it.
    if mode != "linear" and align_corners:
        return

    torch_mode_map = {3: "linear", 4: "bilinear", 5: "trilinear"}
    torch_mode = mode
    if mode == "linear":
        torch_mode = torch_mode_map[len(x_global_shape)]

    torch_align_corners = align_corners
    if mode == "nearest":
        torch_align_corners = None

    P_world = MPIPartition(base_comm)

    # Create the partitions
    P_0_base = P_world.create_partition_inclusive([0])
    P_0 = P_0_base.create_cartesian_topology_partition([1] * len(P_x_shape))

    P_x_base = P_world.create_partition_inclusive(P_x_ranks)
    P_x = P_x_base.create_cartesian_topology_partition(P_x_shape)

    scatter_layer_x = DistributedTranspose(P_0, P_x)
    scatter_layer_y = DistributedTranspose(P_0, P_x)
    gather_layer_x = DistributedTranspose(P_x, P_0)
    gather_layer_y = DistributedTranspose(P_x, P_0)

    if use_size:
        dist_layer = DistributedUpsample(P_x,
                                         size=y_global_shape,
                                         mode=mode,
                                         align_corners=align_corners)
        if P_0.active:
            seq_layer = torch.nn.Upsample(size=y_global_shape[2:],
                                          mode=torch_mode,
                                          align_corners=torch_align_corners)
    else:
        dist_layer = DistributedUpsample(P_x,
                                         scale_factor=scale_factor,
                                         mode=mode,
                                         align_corners=align_corners)
        if P_0.active:
            seq_layer = torch.nn.Upsample(scale_factor=scale_factor,
                                          mode=torch_mode,
                                          align_corners=torch_align_corners)

    # Forward Input
    x_ref = zero_volume_tensor()
    x_ref.requires_grad = True
    dy_ref = zero_volume_tensor()

    # Construct the inputs to the forward and backward functions as well as
    # the outputs of the sequential layer
    if P_0.active:
        x_ref = torch.randn(*x_global_shape)
        x_ref.requires_grad = True
        y_ref = seq_layer(x_ref)
        y_global_shape_calc = y_ref.shape

        dy_ref = torch.randn(*y_global_shape_calc)

        y_ref.backward(dy_ref)
        dx_ref = x_ref.grad

    # Ensure that the scatter is not part of the computation we are testing
    with torch.no_grad():
        x = scatter_layer_x(x_ref.detach())
        dy = scatter_layer_y(dy_ref.detach())

    x.requires_grad = True

    # Because there is no guarantee that any padding is needed in this test,
    # the input x may pass directly to the halo layer without going through
    # the padding step.  Since the halo exchange operates in-place, that would
    # modify a leaf-node variable in-place, which PyTorch does not allow.
    #
    # Thus, we clone x so that the layer's input is not a leaf node.
    x_clone = x.clone()
    y = dist_layer(x_clone)
    y.backward(dy)
    dx = x.grad

    # Ensure that the gather is not part of the computation we are testing
    with torch.no_grad():
        dx_comp = gather_layer_x(dx.detach())
        y_comp = gather_layer_y(y.detach())

    if P_0.active:

        # Set the absolute tolerance to ~sqrt(e_mach), or the default.
        # PyTorch got its defaults from NumPy, but NumPy defaults to 64-bit
        # floats, not 32-bit floats as torch does.  Consequently, the default
        # torch atol is actually tighter than one can expect from two fp-equal
        # floating point numbers.  The NumPy default of 1e-8 is closer to
        # sqrt(e_mach) for 64-bit numbers.  So we set the 32-bit tolerance a
        # little tighter than sqrt(1e-7), to 1e-5.
        if x_ref.dtype == torch.float64:
            atol = 1e-8
        elif x_ref.dtype == torch.float32:
            atol = 1e-5
        else:
            # torch default
            atol = 1e-8

        # Test the result of each entry independently
        assert torch.allclose(y_ref, y_comp, atol=atol)
        assert torch.allclose(dx_ref, dx_comp, atol=atol)

    P_world.deactivate()
    P_0_base.deactivate()
    P_0.deactivate()
    P_x_base.deactivate()
    P_x.deactivate()
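
# A small sketch of the leaf-node restriction mentioned in the comment above,
# independent of DistDL: an in-place op on a leaf tensor that requires grad
# raises, but cloning first gives a non-leaf tensor that still propagates
# gradients back to the leaf.
import torch

x = torch.zeros(3, requires_grad=True)   # a leaf tensor
# x.add_(1)  # RuntimeError: a leaf Variable that requires grad is being used in an in-place operation

x_clone = x.clone()   # non-leaf, still connected to x in the autograd graph
x_clone.add_(1)       # in-place on the clone is allowed
x_clone.sum().backward()
print(x.grad)         # tensor([1., 1., 1.])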
Example #8
P_x_base = P_world.create_partition_inclusive(in_workers)
P_x = P_x_base.create_cartesian_topology_partition(in_shape)

# Create the output partition (using the last 4 workers)
out_shape = (2, 2)
out_size = np.prod(out_shape)
out_workers = np.arange(P_world.size - out_size, P_world.size)

P_y_base = P_world.create_partition_inclusive(out_workers)
P_y = P_y_base.create_cartesian_topology_partition(out_shape)

# This global tensor shape is among the smallest useful shapes for an example
x_global_shape = np.array([7, 5])

# Create the transpose layer
layer = DistributedTranspose(P_x, P_y, preserve_batch=False)

# Setup the input tensor.  Any worker in P_x will generate its part of the
# input tensor.  Any worker not in P_x will have a zero-volume tensor.
#
# Input tensor will be (on a 1 x 1 partition):
# [ [ 1 1 1 1 1 ]
#   [ 1 1 1 1 1 ]
#   [ 1 1 1 1 1 ]
#   [ 1 1 1 1 1 ]
#   [ 1 1 1 1 1 ]
#   [ 1 1 1 1 1 ]
#   [ 1 1 1 1 1 ] ]
x = zero_volume_tensor()
if P_x.active:
    x_local_shape = slicing.compute_subshape(P_x.shape, P_x.index,
                                             x_global_shape)
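    # (The original snippet is truncated here.  A hedged continuation, assuming
    #  torch is imported earlier in the example: fill the local block with ones,
    #  as the comment above shows for a 1 x 1 partition.)
    x = torch.ones(*x_local_shape)

# Apply the transpose and inspect each output worker's local block.  For a
# 7 x 5 global tensor on the 2 x 2 output partition, the local blocks are
# roughly 4 x 3, 4 x 2, 3 x 3, and 3 x 2.
y = layer(x)
if P_y.active:
    print(P_y.index, y.shape)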
Example #9
def test_transpose_identity(barrier_fence_fixture, comm_split_fixture,
                            P_x_ranks, P_x_shape, x_global_shape, balanced):

    import numpy as np
    import torch

    from distdl.backends.mpi.partition import MPIPartition
    from distdl.nn.transpose import DistributedTranspose
    from distdl.utilities.slicing import compute_subshape
    from distdl.utilities.torch import zero_volume_tensor

    device = torch.device('cuda' if use_cuda else 'cpu')

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return
    P_world = MPIPartition(base_comm)

    # Create the partitions
    P_x_base = P_world.create_partition_inclusive(P_x_ranks)
    P_x = P_x_base.create_cartesian_topology_partition(P_x_shape)

    P_y = P_x

    # The global tensor size is the same for x and y
    layer = DistributedTranspose(P_x, P_y, preserve_batch=False)
    layer = layer.to(device)

    # Forward Input
    x = zero_volume_tensor(device=device)
    if P_x.active:
        if balanced:
            x_local_shape = compute_subshape(P_x.shape, P_x.index,
                                             x_global_shape)
        else:
            quotient = np.atleast_1d(x_global_shape) // np.atleast_1d(
                P_x_shape)
            remainder = np.atleast_1d(x_global_shape) % np.atleast_1d(
                P_x_shape)
            loc = np.where(P_x.index == 0)
            x_local_shape = quotient.copy()
            x_local_shape[loc] += remainder[loc]

        x = torch.randn(*x_local_shape, device=device)

    x.requires_grad = True

    # Adjoint Input
    dy = zero_volume_tensor(device=device)
    if P_y.active:
        y_local_shape = compute_subshape(P_y.shape, P_y.index, x_global_shape)
        dy = torch.randn(*y_local_shape, device=device)

    # y = F @ x
    y = layer(x)

    # In the balanced case, this should be a true identity, so there should
    # be no communication performed, just self-copies.
    if balanced:
        for sl, sz, p in layer.P_x_to_y_overlaps:
            assert p == "self" or (sl, sz, p) == (None, None, None)
        for sl, sz, p in layer.P_y_to_x_overlaps:
            assert p == "self" or (sl, sz, p) == (None, None, None)

    # dx = F* @ dy
    y.backward(dy)
    dx = x.grad

    x = x.detach()
    dx = dx.detach()
    dy = dy.detach()
    y = y.detach()

    check_adjoint_test_tight(P_world, x, dx, y, dy)

    P_world.deactivate()
    P_x_base.deactivate()
    P_x.deactivate()
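
# A small numeric sketch of the unbalanced branch above, using hypothetical
# shapes: the workers whose index is 0 along a dimension absorb that dimension's
# remainder, so their local blocks are larger than the balanced subshapes.
import numpy as np

x_global_shape = np.array([7, 6])   # hypothetical global shape
P_x_shape = np.array([2, 2])        # hypothetical partition shape
index = np.array([0, 1])            # hypothetical worker index

quotient = x_global_shape // P_x_shape   # [3, 3]
remainder = x_global_shape % P_x_shape   # [1, 0]

loc = np.where(index == 0)
x_local_shape = quotient.copy()
x_local_shape[loc] += remainder[loc]
print(x_local_shape)                     # [4 3]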
Example #10
def test_transpose_dtype(barrier_fence_fixture, comm_split_fixture, dtype,
                         test_backward, P_x_ranks, P_x_shape, P_y_ranks,
                         P_y_shape, x_global_shape):

    import torch

    from distdl.backends.mpi.partition import MPIPartition
    from distdl.nn.transpose import DistributedTranspose
    from distdl.utilities.slicing import compute_subshape
    from distdl.utilities.torch import zero_volume_tensor

    device = torch.device('cuda' if use_cuda else 'cpu')

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return
    P_world = MPIPartition(base_comm)

    # Create the partitions
    P_x_base = P_world.create_partition_inclusive(P_x_ranks)
    P_x = P_x_base.create_cartesian_topology_partition(P_x_shape)

    P_y_base = P_world.create_partition_inclusive(P_y_ranks)
    P_y = P_y_base.create_cartesian_topology_partition(P_y_shape)

    # The global tensor size is the same for x and y
    layer = DistributedTranspose(P_x, P_y, preserve_batch=False)
    layer = layer.to(device)

    # Forward Input
    x = zero_volume_tensor(dtype=dtype, device=device)
    if P_x.active:
        x_local_shape = compute_subshape(P_x.shape, P_x.index, x_global_shape)
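        # Scale before casting so the values remain non-trivial after conversion to dtype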
        x = 10 * torch.randn(*x_local_shape, device=device).to(dtype)

    x.requires_grad = test_backward
    # y = F @ x
    y = layer(x)
    if P_y.active:
        assert y.dtype == dtype

    if test_backward:
        # Adjoint Input
        dy = zero_volume_tensor(dtype=dtype, device=device)
        if P_y.active:
            y_local_shape = compute_subshape(P_y.shape, P_y.index,
                                             x_global_shape)
            dy = 10 * torch.randn(*y_local_shape, device=device).to(dtype)

        # dx = F* @ dy
        y.backward(dy)
        dx = x.grad
        if P_x.active:
            assert dx.dtype == dtype

    P_world.deactivate()
    P_x_base.deactivate()
    P_x.deactivate()
    P_y_base.deactivate()
    P_y.deactivate()