def test_simple_conv2d_adjoint_weight(barrier_fence_fixture,
                                      comm_split_fixture,
                                      P_x_ranks, P_x_shape,
                                      x_global_shape):

    import numpy as np
    import torch

    from distdl.backends.mpi.partition import MPIPartition
    from distdl.nn.conv_feature import DistributedFeatureConv2d
    from distdl.utilities.slicing import compute_subshape
    from distdl.utilities.torch import zero_volume_tensor

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return
    P_world = MPIPartition(base_comm)

    # Create the partitions
    P_x_base = P_world.create_partition_inclusive(P_x_ranks)
    P_x = P_x_base.create_cartesian_topology_partition(P_x_shape)

    x_global_shape = np.asarray(x_global_shape)

    layer = DistributedFeatureConv2d(P_x,
                                     in_channels=x_global_shape[1],
                                     out_channels=10,
                                     kernel_size=[3, 3],
                                     bias=False)

    x = zero_volume_tensor(x_global_shape[0])
    if P_x.active:
        x_local_shape = compute_subshape(P_x.shape,
                                         P_x.index,
                                         x_global_shape)
        x = torch.randn(*x_local_shape)
    x.requires_grad = True

    y = layer(x)

    dy = zero_volume_tensor(x_global_shape[0])
    if P_x.active:
        dy = torch.randn(*y.shape)

    y.backward(dy)

    W = zero_volume_tensor()
    dW = zero_volume_tensor()
    if P_x.active:
        W = layer.weight.detach()
        dW = layer.weight.grad.detach()

    dy = dy.detach()
    y = y.detach()

    check_adjoint_test_tight(P_world, W, dW, y, dy)

    P_world.deactivate()
    P_x_base.deactivate()
    P_x.deactivate()
def test_linear_adjoint_input(barrier_fence_fixture,
                              comm_split_fixture,
                              P_x_ranks, P_x_shape,
                              P_y_ranks, P_y_shape,
                              P_w_ranks, P_w_shape,
                              x_global_shape,
                              y_global_shape):

    import numpy as np
    import torch

    from distdl.backends.mpi.partition import MPIPartition
    from distdl.nn.linear import DistributedLinear
    from distdl.utilities.slicing import compute_subshape
    from distdl.utilities.torch import zero_volume_tensor

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return
    P_world = MPIPartition(base_comm)

    # Create the partitions
    P_x_base = P_world.create_partition_inclusive(P_x_ranks)
    P_x = P_x_base.create_cartesian_topology_partition(P_x_shape)

    P_y_base = P_world.create_partition_inclusive(P_y_ranks)
    P_y = P_y_base.create_cartesian_topology_partition(P_y_shape)

    P_w_base = P_world.create_partition_inclusive(P_w_ranks)
    P_w = P_w_base.create_cartesian_topology_partition(P_w_shape)

    x_global_shape = np.asarray(x_global_shape)
    y_global_shape = np.asarray(y_global_shape)

    layer = DistributedLinear(P_x, P_y, P_w,
                              x_global_shape[1],
                              y_global_shape[1],
                              bias=False)

    x = zero_volume_tensor(x_global_shape[0])
    if P_x.active:
        x_local_shape = compute_subshape(P_x.shape,
                                         P_x.index,
                                         x_global_shape)
        x = torch.Tensor(np.random.randn(*x_local_shape))
    x.requires_grad = True

    y = layer(x)

    dy = zero_volume_tensor(x_global_shape[0])
    if P_y.active:
        dy = torch.Tensor(np.random.randn(*y.shape))

    y.backward(dy)
    dx = x.grad

    x = x.detach()
    dx = dx.detach()
    dy = dy.detach()
    y = y.detach()

    check_adjoint_test_tight(P_world, x, dx, y, dy)
def test_transpose_adjoint(barrier_fence_fixture,
                           comm_split_fixture,
                           P_x_ranks, P_x_shape,
                           P_y_ranks, P_y_shape,
                           x_global_shape):

    import numpy as np
    import torch

    from distdl.backends.mpi.partition import MPIPartition
    from distdl.nn.transpose import DistributedTranspose
    from distdl.utilities.slicing import compute_subshape
    from distdl.utilities.torch import zero_volume_tensor

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return
    P_world = MPIPartition(base_comm)

    # Create the partitions
    P_x_base = P_world.create_partition_inclusive(P_x_ranks)
    P_x = P_x_base.create_cartesian_topology_partition(P_x_shape)

    P_y_base = P_world.create_partition_inclusive(P_y_ranks)
    P_y = P_y_base.create_cartesian_topology_partition(P_y_shape)

    # The global tensor size is the same for x and y
    layer = DistributedTranspose(P_x, P_y, preserve_batch=False)

    # Forward Input
    x = zero_volume_tensor()
    if P_x.active:
        x_local_shape = compute_subshape(P_x.shape,
                                         P_x.index,
                                         x_global_shape)
        x = torch.Tensor(np.random.randn(*x_local_shape))
    x.requires_grad = True

    # Adjoint Input
    dy = zero_volume_tensor()
    if P_y.active:
        y_local_shape = compute_subshape(P_y.shape,
                                         P_y.index,
                                         x_global_shape)
        dy = torch.Tensor(np.random.randn(*y_local_shape))

    # y = F @ x
    y = layer(x)

    # dx = F* @ dy
    y.backward(dy)
    dx = x.grad

    x = x.detach()
    dx = dx.detach()
    dy = dy.detach()
    y = y.detach()

    check_adjoint_test_tight(P_world, x, dx, y, dy)
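# For reference, the tight adjoint check used throughout these tests verifies
# the identity <F @ x, dy> == <x, F* @ dy> for the linear operator F that the
# layer implements.  The sketch below is a single-process illustration only:
# `_check_adjoint_identity` is a hypothetical helper, and the suite's actual
# `check_adjoint_test_tight` also reduces the inner products across the ranks
# of `P_world` and ties its tolerance to machine precision.
def _check_adjoint_identity(x, dx, y, dy, rtol=1e-6):
    import torch

    # <F @ x, dy>, formed from the forward output
    ip_forward = torch.sum(y * dy)
    # <x, F* @ dy>, formed from the adjoint (backward) output
    ip_adjoint = torch.sum(x * dx)

    # For a linear operator, the two inner products agree up to roundoff.
    scale = y.norm() * dy.norm() + x.norm() * dx.norm()
    assert torch.abs(ip_forward - ip_adjoint) <= rtol * scale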
def test_broadcast_adjoint(barrier_fence_fixture,
                           comm_split_fixture,
                           P_x_ranks, P_x_shape,
                           P_y_ranks, P_y_shape,
                           x_global_shape,
                           transpose_src):

    import numpy as np
    import torch

    from distdl.backends.mpi.partition import MPIPartition
    from distdl.nn.broadcast import Broadcast
    from distdl.utilities.torch import zero_volume_tensor

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return
    P_world = MPIPartition(base_comm)

    # Create the partitions
    P_x_base = P_world.create_partition_inclusive(P_x_ranks)
    P_x = P_x_base.create_cartesian_topology_partition(P_x_shape)

    P_y_base = P_world.create_partition_inclusive(P_y_ranks)
    P_y = P_y_base.create_cartesian_topology_partition(P_y_shape)

    # TODO #93: Change this to create a subtensor so we test when local tensors
    # have different shape. Then, the output size will also be different, which
    # we will have to get from `y` itself.
    x_local_shape = np.asarray(x_global_shape)

    layer = Broadcast(P_x, P_y, transpose_src=transpose_src,
                      preserve_batch=False)

    x = zero_volume_tensor()
    if P_x.active:
        x = torch.Tensor(np.random.randn(*x_local_shape))
    x.requires_grad = True

    dy = zero_volume_tensor()
    if P_y.active:
        # Adjoint Input
        dy = torch.Tensor(np.random.randn(*x_local_shape))

    # y = F @ x
    y = layer(x)

    # dx = F* @ dy
    y.backward(dy)
    dx = x.grad

    x = x.detach()
    dx = dx.detach()
    dy = dy.detach()
    y = y.detach()

    check_adjoint_test_tight(P_world, x, dx, y, dy)
def test_padnd_adjoint(barrier_fence_fixture,
                       comm_split_fixture,
                       x_local_shape,
                       padding):

    import numpy as np
    import torch

    from distdl.backends.mpi.partition import MPIPartition
    from distdl.nn.padnd import PadNd

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return
    P_world = MPIPartition(base_comm)

    x_local_shape = np.asarray(x_local_shape)
    padding = np.asarray(padding)

    padded_shape = [t + lpad + rpad for t, (lpad, rpad)
                    in zip(x_local_shape, padding)]

    layer = PadNd(padding, value=0)

    x = torch.tensor(np.random.randn(*x_local_shape))
    x.requires_grad = True

    dy = torch.tensor(np.random.randn(*padded_shape))

    y = layer(x)

    y.backward(dy)
    dx = x.grad

    x = x.detach()
    dx = dx.detach()
    dy = dy.detach()
    y = y.detach()

    check_adjoint_test_tight(P_world, x, dx, y, dy)
def test_general_conv1d_adjoint_bias(barrier_fence_fixture,
                                     comm_split_fixture,
                                     P_x_ranks, P_x_shape,
                                     P_y_ranks, P_y_shape,
                                     P_w_ranks, P_w_shape,
                                     x_global_shape):

    import numpy as np
    import torch

    from distdl.backends.mpi.partition import MPIPartition
    from distdl.nn.conv_general import DistributedGeneralConv1d
    from distdl.utilities.slicing import compute_subshape
    from distdl.utilities.torch import zero_volume_tensor

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return
    P_world = MPIPartition(base_comm)

    # Create the partitions
    P_x_base = P_world.create_partition_inclusive(P_x_ranks)
    P_x = P_x_base.create_cartesian_topology_partition(P_x_shape)

    P_y_base = P_world.create_partition_inclusive(P_y_ranks)
    P_y = P_y_base.create_cartesian_topology_partition(P_y_shape)

    P_w_base = P_world.create_partition_inclusive(P_w_ranks)
    P_w = P_w_base.create_cartesian_topology_partition(P_w_shape)

    x_global_shape = np.asarray(x_global_shape)

    layer = DistributedGeneralConv1d(P_x, P_y, P_w,
                                     in_channels=x_global_shape[1],
                                     out_channels=10,
                                     kernel_size=[3],
                                     bias=True)

    x = zero_volume_tensor(x_global_shape[0])
    if P_x.active:
        x_local_shape = compute_subshape(P_x.shape,
                                         P_x.index,
                                         x_global_shape)
        # Zeroing x isolates the bias block of the Jacobian; see the
        # explanation in test_linear_adjoint_bias below.
        x = torch.zeros(*x_local_shape)
    x.requires_grad = True

    y = layer(x)

    dy = zero_volume_tensor(x_global_shape[0])
    if P_y.active:
        dy = torch.randn(*y.shape)

    y.backward(dy)

    b = zero_volume_tensor()
    db = zero_volume_tensor()
    if P_w.active and layer.stores_bias:
        b = layer.bias.detach()
        db = layer.bias.grad.detach()

    dy = dy.detach()
    y = y.detach()

    check_adjoint_test_tight(P_world, b, db, y, dy)

    P_world.deactivate()
    P_x_base.deactivate()
    P_x.deactivate()
    P_y_base.deactivate()
    P_y.deactivate()
    P_w_base.deactivate()
    P_w.deactivate()
def test_linear_adjoint_bias(barrier_fence_fixture,
                             comm_split_fixture,
                             P_x_ranks, P_x_shape,
                             P_y_ranks, P_y_shape,
                             P_w_ranks, P_w_shape,
                             x_global_shape,
                             y_global_shape):

    import numpy as np
    import torch

    from distdl.backends.mpi.partition import MPIPartition
    from distdl.nn.linear import DistributedLinear
    from distdl.utilities.slicing import compute_subshape
    from distdl.utilities.torch import zero_volume_tensor

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return
    P_world = MPIPartition(base_comm)

    # Create the partitions
    P_x_base = P_world.create_partition_inclusive(P_x_ranks)
    P_x = P_x_base.create_cartesian_topology_partition(P_x_shape)

    P_y_base = P_world.create_partition_inclusive(P_y_ranks)
    P_y = P_y_base.create_cartesian_topology_partition(P_y_shape)

    P_w_base = P_world.create_partition_inclusive(P_w_ranks)
    P_w = P_w_base.create_cartesian_topology_partition(P_w_shape)

    x_global_shape = np.asarray(x_global_shape)
    y_global_shape = np.asarray(y_global_shape)

    layer = DistributedLinear(P_x, P_y, P_w,
                              x_global_shape[1],
                              y_global_shape[1],
                              bias=True)

    x = zero_volume_tensor(x_global_shape[0])
    if P_x.active:
        x_local_shape = compute_subshape(P_x.shape,
                                         P_x.index,
                                         x_global_shape)
        # For this test, we are only testing to see if the adjoint works
        # correctly for the bias term.  But the adjoint test only works on
        # the Jacobian of the linear layer.  The Jacobian block for b does
        # not depend on x or W, so killing x makes the forward operator
        # equal to its Jacobian, and we can test that the adjoint is
        # computed correctly.
        x = torch.zeros(*x_local_shape)
    x.requires_grad = True

    y = layer(x)

    dy = zero_volume_tensor(x_global_shape[0])
    if P_y.active:
        dy = torch.Tensor(np.random.randn(*y.shape))

    y.backward(dy)

    b = zero_volume_tensor()
    db = zero_volume_tensor()
    if P_w.active and P_w.index[-1] == 0:
        b = layer.sublinear.bias.detach()
        db = layer.sublinear.bias.grad.detach()

    dy = dy.detach()
    y = y.detach()

    check_adjoint_test_tight(P_world, b, db, y, dy)
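# A single-process illustration of the comment in test_linear_adjoint_bias
# above, using a plain torch.nn.Linear in place of the distributed layer
# (the helper below is for exposition only): with x = 0 the forward map
# reduces to y = b, broadcast over the batch, which equals its own Jacobian
# block with respect to b, so <y, dy> == <b, db> up to roundoff.
def _illustrate_bias_adjoint():
    import torch

    layer = torch.nn.Linear(4, 3, bias=True)
    x = torch.zeros(8, 4, requires_grad=True)  # kill the W @ x contribution

    y = layer(x)            # every batch row of y equals b
    dy = torch.randn(*y.shape)
    y.backward(dy)

    b = layer.bias.detach()
    db = layer.bias.grad.detach()  # db[j] == sum_i dy[i, j]

    # <y, dy> == sum_{ij} b[j] dy[i, j] == sum_j b[j] db[j] == <b, db>
    assert torch.allclose(torch.sum(y.detach() * dy),
                          torch.sum(b * db), atol=1e-5)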
def test_halo_exchange_adjoint(barrier_fence_fixture,
                               comm_split_fixture,
                               P_x_ranks, P_x_shape,
                               x_global_shape,
                               kernel_size,
                               stride,
                               padding,
                               dilation,
                               MockKernelStyle):

    import numpy as np
    import torch

    from distdl.backends.mpi.partition import MPIPartition
    from distdl.nn.halo_exchange import HaloExchange
    from distdl.nn.padnd import PadNd
    from distdl.utilities.slicing import compute_subshape
    from distdl.utilities.torch import zero_volume_tensor

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return
    P_world = MPIPartition(base_comm)

    P_x_base = P_world.create_partition_inclusive(P_x_ranks)
    P_x = P_x_base.create_cartesian_topology_partition(P_x_shape)

    x_global_shape = np.asarray(x_global_shape)
    kernel_size = np.asarray(kernel_size)
    stride = np.asarray(stride)
    padding = np.asarray(padding)
    dilation = np.asarray(dilation)

    halo_shape = None
    recv_buffer_shape = None
    send_buffer_shape = None
    if P_x.active:
        mockup_layer = MockKernelStyle()
        exchange_info = mockup_layer._compute_exchange_info(x_global_shape,
                                                            kernel_size,
                                                            stride,
                                                            padding,
                                                            dilation,
                                                            P_x.active,
                                                            P_x.shape,
                                                            P_x.index)
        halo_shape = exchange_info[0]
        recv_buffer_shape = exchange_info[1]
        send_buffer_shape = exchange_info[2]

    pad_layer = PadNd(halo_shape, value=0)
    halo_layer = HaloExchange(P_x, halo_shape, recv_buffer_shape,
                              send_buffer_shape)

    x = zero_volume_tensor(x_global_shape[0])
    if P_x.active:
        x_local_shape = compute_subshape(P_x.shape,
                                         P_x.index,
                                         x_global_shape)
        x = torch.tensor(np.random.randn(*x_local_shape))
        x = pad_layer.forward(x)
    x.requires_grad = True

    dy = zero_volume_tensor(x_global_shape[0])
    if P_x.active:
        dy = torch.tensor(np.random.randn(*x.shape))

    x_clone = x.clone()
    dy_clone = dy.clone()

    # x_clone is modified in place by halo_layer, but we assign y to
    # reference it for clarity
    y = halo_layer(x_clone)

    # dy_clone is modified in place by halo_layer's adjoint, but we assign
    # dx to reference it for clarity
    y.backward(dy_clone)
    dx = dy_clone

    x = x.detach()
    dx = dx.detach()
    dy = dy.detach()
    y = y.detach()

    check_adjoint_test_tight(P_world, x, dx, y, dy)
def test_buffer_management_mixed_network(barrier_fence_fixture,
                                         comm_split_fixture):

    import numpy as np
    import torch

    import distdl
    from distdl.backends.mpi.buffer import MPIBufferManager
    from distdl.backends.mpi.partition import MPIPartition
    from distdl.utilities.torch import zero_volume_tensor

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return

    buffer_manager = MPIBufferManager()
    P_world = MPIPartition(base_comm)

    # Create the partitions
    P_1_base = P_world.create_partition_inclusive([0])
    P_1 = P_1_base.create_cartesian_topology_partition([1, 1, 1, 1])

    P_22_base = P_world.create_partition_inclusive([0, 1, 2, 3])
    P_22 = P_22_base.create_cartesian_topology_partition([1, 1, 2, 2])

    tr1 = distdl.nn.DistributedTranspose(P_1, P_22,
                                         buffer_manager=buffer_manager)
    c1 = distdl.nn.DistributedConv2d(P_22,
                                     in_channels=1, out_channels=5,
                                     kernel_size=[3, 3], padding=[1, 1],
                                     bias=False,
                                     buffer_manager=buffer_manager)
    c2 = distdl.nn.DistributedConv2d(P_22,
                                     in_channels=5, out_channels=10,
                                     kernel_size=[3, 3], padding=[1, 1],
                                     bias=False,
                                     buffer_manager=buffer_manager)
    c3 = distdl.nn.DistributedConv2d(P_22,
                                     in_channels=10, out_channels=20,
                                     kernel_size=[3, 3], padding=[1, 1],
                                     bias=False,
                                     buffer_manager=buffer_manager)
    tr2 = distdl.nn.DistributedTranspose(P_22, P_1,
                                         buffer_manager=buffer_manager)

    x = zero_volume_tensor(1)
    if P_1.active:
        x = torch.randn(1, 1, 5, 5)
    x.requires_grad = True

    # [[00 01 02 03 04]        [[00 01 02]  [[03 04]
    #  [10 11 12 13 14]         [10 11 12]   [13 14]]
    #  [20 21 22 23 24]   to    [20 21 22]] [[23 24]
    #  [30 31 32 33 34]        [[30 31 32]   [33 34]
    #  [40 41 42 43 44]]        [40 41 42]]  [43 44]]
    x2 = tr1(x)
    n_buffers_by_rank = (3, 1, 1, 1)
    assert len(buffer_manager.buffers_map[np.float32]) == \
        n_buffers_by_rank[P_world.rank]

    # [[00 01 02]  [[03 04]       [[00 01 02]  [[03 04]
    #  [10 11 12]   [13 14]]       [10 11 12]   [13 14]]
    #  [20 21 22]] [[23 24]   to   [20 21 22]] [[23 24]
    # [[30 31 32]   [33 34]       [[30 31 32]   [33 34]
    #  [40 41 42]]  [43 44]]       [40 41 42]]  [43 44]]
    x3 = c1(x2)
    n_buffers_by_rank = (4, 4, 4, 4)
    assert len(buffer_manager.buffers_map[np.float32]) == \
        n_buffers_by_rank[P_world.rank]

    # [[00 01 02]  [[03 04]       [[00 01 02]  [[03 04]
    #  [10 11 12]   [13 14]]       [10 11 12]   [13 14]]
    #  [20 21 22]] [[23 24]   to   [20 21 22]] [[23 24]
    # [[30 31 32]   [33 34]       [[30 31 32]   [33 34]
    #  [40 41 42]]  [43 44]]       [40 41 42]]  [43 44]]
    x4 = c2(x3)
    n_buffers_by_rank = (4, 4, 4, 4)
    assert len(buffer_manager.buffers_map[np.float32]) == \
        n_buffers_by_rank[P_world.rank]

    # [[00 01 02]  [[03 04]       [[00 01 02]  [[03 04]
    #  [10 11 12]   [13 14]]       [10 11 12]   [13 14]]
    #  [20 21 22]] [[23 24]   to   [20 21 22]] [[23 24]
    # [[30 31 32]   [33 34]       [[30 31 32]   [33 34]
    #  [40 41 42]]  [43 44]]       [40 41 42]]  [43 44]]
    x5 = c3(x4)
    n_buffers_by_rank = (4, 4, 4, 4)
    assert len(buffer_manager.buffers_map[np.float32]) == \
        n_buffers_by_rank[P_world.rank]

    # [[00 01 02]  [[03 04]       [[00 01 02 03 04]
    #  [10 11 12]   [13 14]]       [10 11 12 13 14]
    #  [20 21 22]] [[23 24]   to   [20 21 22 23 24]
    # [[30 31 32]   [33 34]        [30 31 32 33 34]
    #  [40 41 42]]  [43 44]]       [40 41 42 43 44]]
    y = tr2(x5)
    n_buffers_by_rank = (4, 4, 4, 4)
    assert len(buffer_manager.buffers_map[np.float32]) == \
        n_buffers_by_rank[P_world.rank]

    dy = zero_volume_tensor(1)
    if P_1.active:
        dy = torch.randn(1, 20, 5, 5)
    dy.requires_grad = True

    y.backward(dy)
    dx = x.grad

    # Through the backward call, the buffer counts do not change
    n_buffers_by_rank = (4, 4, 4, 4)
    assert len(buffer_manager.buffers_map[np.float32]) == \
        n_buffers_by_rank[P_world.rank]

    # And adjointness is still preserved
    x = x.detach()
    dx = dx.detach()
    dy = dy.detach()
    y = y.detach()

    check_adjoint_test_tight(P_world, x, dx, y, dy)

    P_world.deactivate()
    P_1_base.deactivate()
    P_1.deactivate()
    P_22_base.deactivate()
    P_22.deactivate()
def test_buffer_management_transpose_network(barrier_fence_fixture,
                                             comm_split_fixture):

    import numpy as np
    import torch

    import distdl
    from distdl.backends.mpi.buffer import MPIBufferManager
    from distdl.backends.mpi.partition import MPIPartition
    from distdl.utilities.torch import zero_volume_tensor

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return

    buffer_manager = MPIBufferManager()
    P_world = MPIPartition(base_comm)

    # Create the partitions
    P_1_base = P_world.create_partition_inclusive([0])
    P_1 = P_1_base.create_cartesian_topology_partition([1, 1])

    P_3_base = P_world.create_partition_inclusive([1, 2, 3])
    P_3 = P_3_base.create_cartesian_topology_partition([1, 3])

    P_4_base = P_world.create_partition_inclusive([0, 1, 2, 3])
    P_4 = P_4_base.create_cartesian_topology_partition([1, 4])

    tr1 = distdl.nn.DistributedTranspose(P_1, P_4,
                                         buffer_manager=buffer_manager)
    tr2 = distdl.nn.DistributedTranspose(P_4, P_4,
                                         buffer_manager=buffer_manager)
    tr3 = distdl.nn.DistributedTranspose(P_4, P_3,
                                         buffer_manager=buffer_manager)
    tr4 = distdl.nn.DistributedTranspose(P_3, P_1,
                                         buffer_manager=buffer_manager)

    x = zero_volume_tensor(1)
    if P_1.active:
        x = torch.randn(1, 10)
    x.requires_grad = True

    # [0 1 2 3 4 5 6 7 8 9] to
    # [0 1 2] [3 4 5] [6 7] [8 9]
    x2 = tr1(x)
    n_buffers_by_rank = (3, 1, 1, 1)
    assert len(buffer_manager.buffers_map[np.float32]) == \
        n_buffers_by_rank[P_world.rank]

    # [0 1 2] [3 4 5] [6 7] [8 9] to
    # [0 1 2] [3 4 5] [6 7] [8 9]
    x3 = tr2(x2)
    n_buffers_by_rank = (3, 1, 1, 1)
    assert len(buffer_manager.buffers_map[np.float32]) == \
        n_buffers_by_rank[P_world.rank]

    # [0 1 2] [3 4 5] [6 7] [8 9] to
    # [] [0 1 2 3] [4 5 6] [7 8 9]
    x4 = tr3(x3)
    n_buffers_by_rank = (3, 2, 2, 1)
    assert len(buffer_manager.buffers_map[np.float32]) == \
        n_buffers_by_rank[P_world.rank]

    # [] [0 1 2 3] [4 5 6] [7 8 9] to
    # [0 1 2 3 4 5 6 7 8 9]
    y = tr4(x4)
    n_buffers_by_rank = (3, 2, 2, 1)
    assert len(buffer_manager.buffers_map[np.float32]) == \
        n_buffers_by_rank[P_world.rank]

    dy = zero_volume_tensor(1)
    if P_1.active:
        dy = torch.randn(1, 10)
    dy.requires_grad = True

    y.backward(dy)
    dx = x.grad

    # Through the backward call, the buffer counts do not change
    n_buffers_by_rank = (3, 2, 2, 1)
    assert len(buffer_manager.buffers_map[np.float32]) == \
        n_buffers_by_rank[P_world.rank]

    # And adjointness is still preserved
    x = x.detach()
    dx = dx.detach()
    dy = dy.detach()
    y = y.detach()

    check_adjoint_test_tight(P_world, x, dx, y, dy)

    P_world.deactivate()
    P_1_base.deactivate()
    P_1.deactivate()
    P_3_base.deactivate()
    P_3.deactivate()
    P_4_base.deactivate()
    P_4.deactivate()
def test_repartition_identity(barrier_fence_fixture,
                              comm_split_fixture,
                              P_x_ranks, P_x_shape,
                              x_global_shape,
                              balanced):

    import numpy as np
    import torch

    from distdl.backends.mpi.partition import MPIPartition
    from distdl.nn.repartition import Repartition
    from distdl.utilities.slicing import compute_subshape
    from distdl.utilities.torch import zero_volume_tensor

    device = torch.device('cuda' if use_cuda else 'cpu')

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return
    P_world = MPIPartition(base_comm)

    # Create the partitions
    P_x_base = P_world.create_partition_inclusive(P_x_ranks)
    P_x = P_x_base.create_cartesian_topology_partition(P_x_shape)
    P_y = P_x

    # The global tensor size is the same for x and y
    layer = Repartition(P_x, P_y, preserve_batch=False)
    layer = layer.to(device)

    # Forward Input
    x = zero_volume_tensor(device=device)
    if P_x.active:
        if balanced:
            x_local_shape = compute_subshape(P_x.shape,
                                             P_x.index,
                                             x_global_shape)
        else:
            quotient = np.atleast_1d(x_global_shape) // np.atleast_1d(P_x_shape)
            remainder = np.atleast_1d(x_global_shape) % np.atleast_1d(P_x_shape)
            loc = np.where(P_x.index == 0)
            x_local_shape = quotient.copy()
            x_local_shape[loc] += remainder[loc]
        x = torch.randn(*x_local_shape, device=device)
    x.requires_grad = True

    # Adjoint Input
    dy = zero_volume_tensor(device=device)
    if P_y.active:
        y_local_shape = compute_subshape(P_y.shape,
                                         P_y.index,
                                         x_global_shape)
        dy = torch.randn(*y_local_shape, device=device)

    # y = F @ x
    y = layer(x)

    # In the balanced case, this should be a true identity, so there should
    # be no communication performed, just self-copies.
    if balanced:
        for sl, sz, p in layer.P_x_to_y_overlaps:
            assert p == "self" or (sl, sz, p) == (None, None, None)

        for sl, sz, p in layer.P_y_to_x_overlaps:
            assert p == "self" or (sl, sz, p) == (None, None, None)

    # dx = F* @ dy
    y.backward(dy)
    dx = x.grad

    x = x.detach()
    dx = dx.detach()
    dy = dy.detach()
    y = y.detach()

    check_adjoint_test_tight(P_world, x, dx, y, dy)

    P_world.deactivate()
    P_x_base.deactivate()
    P_x.deactivate()
def test_halo_exchange_adjoint(barrier_fence_fixture,
                               comm_split_fixture,
                               P_x_ranks, P_x_shape,
                               x_global_shape,
                               dtype,
                               kernel_size,
                               stride,
                               padding,
                               dilation,
                               MockKernelStyle):

    import numpy as np
    import torch
    import torch.nn.functional as F

    from distdl.backends.mpi.partition import MPIPartition
    from distdl.nn.halo_exchange import HaloExchange
    from distdl.utilities.slicing import compute_subshape
    from distdl.utilities.torch import distdl_padding_to_torch_padding
    from distdl.utilities.torch import zero_volume_tensor

    device = torch.device('cuda' if use_cuda else 'cpu')

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return
    P_world = MPIPartition(base_comm)

    P_x_base = P_world.create_partition_inclusive(P_x_ranks)
    P_x = P_x_base.create_cartesian_topology_partition(P_x_shape)

    x_global_shape = np.asarray(x_global_shape)
    kernel_size = np.asarray(kernel_size)
    stride = np.asarray(stride)
    padding = np.asarray(padding)
    dilation = np.asarray(dilation)

    halo_shape = None
    recv_buffer_shape = None
    send_buffer_shape = None
    if P_x.active:
        mockup_layer = MockKernelStyle()
        exchange_info = mockup_layer._compute_exchange_info(x_global_shape,
                                                            kernel_size,
                                                            stride,
                                                            padding,
                                                            dilation,
                                                            P_x.active,
                                                            P_x.shape,
                                                            P_x.index)
        halo_shape = exchange_info[0]
        recv_buffer_shape = exchange_info[1]
        send_buffer_shape = exchange_info[2]

    halo_layer = HaloExchange(P_x, halo_shape, recv_buffer_shape,
                              send_buffer_shape)
    halo_layer = halo_layer.to(device)

    x = zero_volume_tensor(x_global_shape[0], device=device)
    dy = zero_volume_tensor(x_global_shape[0], device=device)
    if P_x.active:
        x_local_shape = compute_subshape(P_x.shape,
                                         P_x.index,
                                         x_global_shape)

        padding = distdl_padding_to_torch_padding(halo_shape)

        x = torch.randn(*x_local_shape, device=device).to(dtype)

        # Pad the input with the halo space.  We are only testing the
        # behavior of the halo exchange so the input must be padded before
        # we can do anything.
        x = F.pad(x, pad=padding, mode="constant", value=0)

        # dy is also padded, but we wanted it to start with data inside it.
        dy = torch.randn(*x.shape, device=device).to(dtype)

    x.requires_grad = True

    # Halo exchange (both fwd and adj) is in-place.  So, we copy the input
    # data and save the original for the adjoint test.  Because it is
    # in-place, the clones themselves are modified.  This also prevents
    # issues with in-place operations on leaf nodes.
    x_clone = x.clone()
    dy_clone = dy.clone()

    # x_clone is modified in place by halo_layer, but we assign y to
    # reference it for clarity.  y and x_clone are the same object.
    y = halo_layer(x_clone)

    # dy_clone is modified in place by halo_layer's adjoint, but we assign
    # dx to reference it for clarity.  dx and dy_clone are the same object.
    # dx is not in the grad field as you might expect because the operation
    # is in-place.
    y.backward(dy_clone)
    dx = dy_clone

    x = x.detach()
    dx = dx.detach()
    dy = dy.detach()
    y = y.detach()

    check_adjoint_test_tight(P_world, x, dx, y, dy)

    P_world.deactivate()
    P_x_base.deactivate()
    P_x.deactivate()
def test_transpose_adjoint(barrier_fence_fixture,
                           comm_split_fixture,
                           P_x_ranks, P_x_shape,
                           P_y_ranks, P_y_shape,
                           x_global_shape,
                           balanced):

    import numpy as np
    import torch

    from distdl.backends.mpi.partition import MPIPartition
    from distdl.nn.transpose import DistributedTranspose
    from distdl.utilities.slicing import compute_subshape
    from distdl.utilities.torch import zero_volume_tensor

    device = torch.device('cuda' if use_cuda else 'cpu')

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return
    P_world = MPIPartition(base_comm)

    # Create the partitions
    P_x_base = P_world.create_partition_inclusive(P_x_ranks)
    P_x = P_x_base.create_cartesian_topology_partition(P_x_shape)

    P_y_base = P_world.create_partition_inclusive(P_y_ranks)
    P_y = P_y_base.create_cartesian_topology_partition(P_y_shape)

    # The global tensor size is the same for x and y
    layer = DistributedTranspose(P_x, P_y, preserve_batch=False)
    layer = layer.to(device)

    # Forward Input
    x = zero_volume_tensor(device=device)
    if P_x.active:
        if balanced:
            x_local_shape = compute_subshape(P_x.shape,
                                             P_x.index,
                                             x_global_shape)
        else:
            quotient = np.atleast_1d(x_global_shape) // np.atleast_1d(P_x_shape)
            remainder = np.atleast_1d(x_global_shape) % np.atleast_1d(P_x_shape)
            loc = np.where(P_x.index == 0)
            x_local_shape = quotient.copy()
            x_local_shape[loc] += remainder[loc]
        x = torch.randn(*x_local_shape, device=device)
    x.requires_grad = True

    # Adjoint Input
    dy = zero_volume_tensor(device=device)
    if P_y.active:
        y_local_shape = compute_subshape(P_y.shape,
                                         P_y.index,
                                         x_global_shape)
        dy = torch.randn(*y_local_shape, device=device)

    # y = F @ x
    y = layer(x)

    # dx = F* @ dy
    y.backward(dy)
    dx = x.grad

    x = x.detach()
    dx = dx.detach()
    dy = dy.detach()
    y = y.detach()

    check_adjoint_test_tight(P_world, x, dx, y, dy)

    P_world.deactivate()
    P_x_base.deactivate()
    P_x.deactivate()
    P_y_base.deactivate()
    P_y.deactivate()
def test_all_sum_reduce_adjoint(barrier_fence_fixture,
                                comm_split_fixture,
                                P_x_ranks, P_x_shape,
                                x_global_shape,
                                axes_reduce):

    import numpy as np
    import torch

    from distdl.backends.mpi.partition import MPIPartition
    from distdl.nn.all_sum_reduce import AllSumReduce
    from distdl.utilities.torch import zero_volume_tensor

    device = torch.device('cuda' if use_cuda else 'cpu')

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return
    P_world = MPIPartition(base_comm)

    # Create the partitions
    P_x_base = P_world.create_partition_inclusive(P_x_ranks)
    P_x = P_x_base.create_cartesian_topology_partition(P_x_shape)

    # TODO #93: Change this to create a subtensor so we test when local tensors
    # have different shape. Then, the output size will also be different, which
    # we will have to get from `y` itself.
    x_local_shape = np.asarray(x_global_shape)

    layer = AllSumReduce(P_x, axes_reduce)
    layer = layer.to(device)

    x = zero_volume_tensor(device=device)
    if P_x.active:
        x = 10*torch.ones(*x_local_shape, device=device)
    x.requires_grad = True

    dy = zero_volume_tensor(device=device)
    if P_x.active:
        # Adjoint Input
        dy = 0.1*torch.ones(*x_local_shape, device=device)

    # y = F @ x
    y = layer(x)

    # dx = F* @ dy
    y.backward(dy)
    dx = x.grad

    x = x.detach()
    dx = dx.detach()
    dy = dy.detach()
    y = y.detach()

    # Each entry is summed over one copy per worker in the reduced axes, so
    # constants scale by the product of the partition shape along axes_reduce.
    reduced_entry_value = 1
    for k in range(len(P_x_shape)):
        if k in axes_reduce:
            reduced_entry_value *= P_x_shape[k]

    assert(torch.all(y == 10*reduced_entry_value))
    assert(torch.all(dx == 0.1*reduced_entry_value))

    check_adjoint_test_tight(P_world, x, dx, y, dy)

    P_world.deactivate()
    P_x_base.deactivate()
    P_x.deactivate()
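# A toy, single-process model of the arithmetic asserted above (a sketch for
# exposition only, not the DistDL API): an all-sum-reduce sums one copy of the
# local tensor per worker along the reduced axes and replicates the result,
# and its adjoint has the same sum-then-replicate structure, so constant
# inputs scale by the number of participating workers.
def _illustrate_all_sum_reduce(n_workers=4):
    import numpy as np

    # forward: each worker holds 10s; the reduction sums the n_workers
    # copies and replicates the total back to every worker
    xs = [10 * np.ones((3, 3)) for _ in range(n_workers)]
    ys = [sum(xs).copy() for _ in xs]
    assert np.all(ys[0] == 10 * n_workers)

    # adjoint: the same sum-then-replicate applied to dy, so 0.1 becomes
    # 0.1 * n_workers in dx
    dys = [0.1 * np.ones((3, 3)) for _ in range(n_workers)]
    dxs = [sum(dys).copy() for _ in dys]
    assert np.allclose(dxs[0], 0.1 * n_workers)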