def forward(ctx, input, P_send, P_recv, preserve_batch,
            input_tensor_structure, output_tensor_structure, dtype):

    ctx.P_send = P_send
    ctx.P_recv = P_recv
    ctx.preserve_batch = preserve_batch
    ctx.input_tensor_structure = input_tensor_structure
    ctx.output_tensor_structure = output_tensor_structure
    ctx.dtype = dtype

    input_tensor_shape = input_tensor_structure[2]
    output_requires_grad = output_tensor_structure[0]
    output_tensor_shape = output_tensor_structure[2]

    # This allows all ranks to use the same exit path, so that we can be
    # sure that all requests have cleared.
    if preserve_batch:
        output = zero_volume_tensor(input.shape[0])
    else:
        output = zero_volume_tensor()

    requests = []

    # By design, the roots are always 0 in the cross-communicators.

    # If I receive data (either from a remote worker or just from myself),
    # I need to reduce that data.  If I send and receive to myself, this
    # is OK, as the reduction accounts for the copy, unlike the broadcast
    # below.
    if P_send.active:
        reduced_data_send = np.zeros(input_tensor_shape, dtype=dtype)
        input_numpy = input.detach().numpy()
        req = P_send.comm.Ireduce(input_numpy, reduced_data_send,
                                  root=0, op=MPI.SUM)
        requests.append(req)

    # If I sent data in the forward, I have to receive it here.  mpi4py
    # does not allow aliasing of the input, so we have to make a copy of
    # nothing, unfortunately.
    if P_send != P_recv and P_recv.active:
        reduced_data_recv = np.zeros(output_tensor_shape, dtype=dtype)
        req = P_recv.comm.Ireduce(reduced_data_recv.copy(), reduced_data_recv,
                                  root=0, op=MPI.SUM)
        requests.append(req)

    MPI.Request.Waitall(requests)

    # If we had to receive data, we need to tensorify it.
    if P_recv.active:
        if P_send == P_recv:
            output = torch.tensor(reduced_data_send,
                                  requires_grad=output_requires_grad)
        else:
            output = torch.tensor(reduced_data_recv,
                                  requires_grad=output_requires_grad)

    return output
def test_linear_adjoint_input(barrier_fence_fixture,
                              comm_split_fixture,
                              P_x_ranks, P_x_shape,
                              P_y_ranks, P_y_shape,
                              P_w_ranks, P_w_shape,
                              x_global_shape,
                              y_global_shape):

    import numpy as np
    import torch

    from distdl.backends.mpi.partition import MPIPartition
    from distdl.nn.linear import DistributedLinear
    from distdl.utilities.slicing import compute_subshape
    from distdl.utilities.torch import zero_volume_tensor

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return
    P_world = MPIPartition(base_comm)

    # Create the partitions
    P_x_base = P_world.create_partition_inclusive(P_x_ranks)
    P_x = P_x_base.create_cartesian_topology_partition(P_x_shape)

    P_y_base = P_world.create_partition_inclusive(P_y_ranks)
    P_y = P_y_base.create_cartesian_topology_partition(P_y_shape)

    P_w_base = P_world.create_partition_inclusive(P_w_ranks)
    P_w = P_w_base.create_cartesian_topology_partition(P_w_shape)

    x_global_shape = np.asarray(x_global_shape)
    y_global_shape = np.asarray(y_global_shape)

    layer = DistributedLinear(P_x, P_y, P_w,
                              x_global_shape[1], y_global_shape[1],
                              bias=False)

    x = zero_volume_tensor(x_global_shape[0])
    if P_x.active:
        x_local_shape = compute_subshape(P_x.shape,
                                         P_x.index,
                                         x_global_shape)
        x = torch.Tensor(np.random.randn(*x_local_shape))
    x.requires_grad = True

    y = layer(x)

    dy = zero_volume_tensor(x_global_shape[0])
    if P_y.active:
        dy = torch.Tensor(np.random.randn(*y.shape))

    y.backward(dy)
    dx = x.grad

    x = x.detach()
    dx = dx.detach()
    dy = dy.detach()
    y = y.detach()

    check_adjoint_test_tight(P_world, x, dx, y, dy)
def test_simple_conv2d_adjoint_weight(barrier_fence_fixture,
                                      comm_split_fixture,
                                      P_x_ranks, P_x_shape,
                                      x_global_shape):

    import numpy as np
    import torch

    from distdl.backends.mpi.partition import MPIPartition
    from distdl.nn.conv_feature import DistributedFeatureConv2d
    from distdl.utilities.slicing import compute_subshape
    from distdl.utilities.torch import zero_volume_tensor

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return
    P_world = MPIPartition(base_comm)

    # Create the partitions
    P_x_base = P_world.create_partition_inclusive(P_x_ranks)
    P_x = P_x_base.create_cartesian_topology_partition(P_x_shape)

    x_global_shape = np.asarray(x_global_shape)

    layer = DistributedFeatureConv2d(P_x,
                                     in_channels=x_global_shape[1],
                                     out_channels=10,
                                     kernel_size=[3, 3],
                                     bias=False)

    x = zero_volume_tensor(x_global_shape[0])
    if P_x.active:
        x_local_shape = compute_subshape(P_x.shape,
                                         P_x.index,
                                         x_global_shape)
        x = torch.randn(*x_local_shape)
    x.requires_grad = True

    y = layer(x)

    dy = zero_volume_tensor(x_global_shape[0])
    if P_x.active:
        dy = torch.randn(*y.shape)

    y.backward(dy)

    W = zero_volume_tensor()
    dW = zero_volume_tensor()
    if P_x.active:
        W = layer.weight.detach()
        dW = layer.weight.grad.detach()

    dy = dy.detach()
    y = y.detach()

    check_adjoint_test_tight(P_world, W, dW, y, dy)

    P_world.deactivate()
    P_x_base.deactivate()
    P_x.deactivate()
def test_transpose_adjoint(barrier_fence_fixture,
                           comm_split_fixture,
                           P_x_ranks, P_x_shape,
                           P_y_ranks, P_y_shape,
                           x_global_shape):

    import numpy as np
    import torch

    from distdl.backends.mpi.partition import MPIPartition
    from distdl.nn.transpose import DistributedTranspose
    from distdl.utilities.slicing import compute_subshape
    from distdl.utilities.torch import zero_volume_tensor

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return
    P_world = MPIPartition(base_comm)

    # Create the partitions
    P_x_base = P_world.create_partition_inclusive(P_x_ranks)
    P_x = P_x_base.create_cartesian_topology_partition(P_x_shape)

    P_y_base = P_world.create_partition_inclusive(P_y_ranks)
    P_y = P_y_base.create_cartesian_topology_partition(P_y_shape)

    # The global tensor size is the same for x and y
    layer = DistributedTranspose(P_x, P_y, preserve_batch=False)

    # Forward Input
    x = zero_volume_tensor()
    if P_x.active:
        x_local_shape = compute_subshape(P_x.shape,
                                         P_x.index,
                                         x_global_shape)
        x = torch.Tensor(np.random.randn(*x_local_shape))
    x.requires_grad = True

    # Adjoint Input
    dy = zero_volume_tensor()
    if P_y.active:
        y_local_shape = compute_subshape(P_y.shape,
                                         P_y.index,
                                         x_global_shape)
        dy = torch.Tensor(np.random.randn(*y_local_shape))

    # y = F @ x
    y = layer(x)

    # dx = F* @ dy
    y.backward(dy)
    dx = x.grad

    x = x.detach()
    dx = dx.detach()
    dy = dy.detach()
    y = y.detach()

    check_adjoint_test_tight(P_world, x, dx, y, dy)
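# The adjoint tests above all end in `check_adjoint_test_tight`, which lives in
# DistDL's test utilities.  The sketch below is a simplified, hypothetical
# stand-in (the name `dot_product_adjoint_check` and the tolerance handling are
# made up here), assuming the helper performs the standard dot-product test:
# for y = F x and dx = F* dy, the global inner products <y, dy> and <x, dx>
# must agree.  The real helper may scale by norms and use tighter tolerances.
import numpy as np
import torch
from mpi4py import MPI


def dot_product_adjoint_check(x, dx, y, dy, comm=MPI.COMM_WORLD, rtol=1e-6):
    # Local contributions to the two global inner products; zero-volume
    # tensors on inactive workers contribute exactly zero.
    local = np.array([torch.sum(y * dy).item(), torch.sum(x * dx).item()])
    glob = np.zeros_like(local)
    comm.Allreduce(local, glob, op=MPI.SUM)
    ip_y_dy, ip_x_dx = glob
    return np.isclose(ip_y_dy, ip_x_dx, rtol=rtol)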
def backward(ctx, grad_output):

    P_send = ctx.P_send
    P_recv = ctx.P_recv
    preserve_batch = ctx.preserve_batch
    input_tensor_structure = ctx.input_tensor_structure
    output_tensor_structure = ctx.output_tensor_structure
    dtype = ctx.dtype

    input_requires_grad = input_tensor_structure[0]
    input_tensor_shape = input_tensor_structure[2]
    output_tensor_shape = output_tensor_structure[2]

    # This allows all ranks to use the same exit path, so that we can be
    # sure that all requests have cleared.
    if preserve_batch:
        grad_input = zero_volume_tensor(grad_output.shape[0])
    else:
        grad_input = zero_volume_tensor()

    requests = []

    # If I received data (either from a remote worker or just from myself),
    # I need to reduce that data.  If I send and receive to myself, this
    # is OK, as the reduction accounts for the copy, unlike the broadcast
    # above.
    if P_recv.active:
        reduced_data_recv = np.zeros(output_tensor_shape, dtype=dtype)
        grad_output_numpy = grad_output.detach().numpy()
        req = P_recv.comm.Ireduce(grad_output_numpy, reduced_data_recv,
                                  root=0, op=MPI.SUM)
        requests.append(req)

    # If I sent data in the forward, I have to receive it here.  Unless I
    # also received that data, then I already have it from above.  mpi4py
    # does not allow aliasing of the input, so we have to make a copy of
    # nothing, unfortunately.
    if P_send != P_recv and P_send.active:
        reduced_data_send = np.zeros(input_tensor_shape, dtype=dtype)
        req = P_send.comm.Ireduce(reduced_data_send.copy(), reduced_data_send,
                                  root=0, op=MPI.SUM)
        requests.append(req)

    MPI.Request.Waitall(requests)

    # If we had to receive data, we need to tensorify it.
    if P_send.active:
        if P_send == P_recv:
            grad_input = torch.tensor(reduced_data_recv,
                                      requires_grad=input_requires_grad)
        else:
            grad_input = torch.tensor(reduced_data_send,
                                      requires_grad=input_requires_grad)

    return grad_input, None, None, None, None, None, None
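# The backward above sum-reduces precisely because the forward it pairs with is
# a broadcast.  That adjoint pairing can be seen in a dense, single-process toy
# form: broadcasting a length-m vector to n workers is multiplication by a
# stack of identities, and the transpose of that matrix sums the n pieces back
# together.  This is only an illustration of the relationship, not how the
# distributed layer is implemented.
import numpy as np

n, m = 4, 3                                 # n receiving "workers", local length m
B = np.vstack([np.eye(m)] * n)              # broadcast as a dense (n*m x m) matrix

x = np.random.randn(m)                      # input held by the root
dy = np.random.randn(n * m)                 # adjoint inputs, stacked across workers

y = B @ x                                   # forward: every worker gets a copy of x
dx = B.T @ dy                               # adjoint: sum of the n local adjoint inputs

assert np.allclose(dx, dy.reshape(n, m).sum(axis=0))
assert np.isclose(y @ dy, x @ dx)           # dot-product (adjoint) test holds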
def test_broadcast_adjoint(barrier_fence_fixture,
                           comm_split_fixture,
                           P_x_ranks, P_x_shape,
                           P_y_ranks, P_y_shape,
                           x_global_shape,
                           transpose_src):

    import numpy as np
    import torch

    from distdl.backends.mpi.partition import MPIPartition
    from distdl.nn.broadcast import Broadcast
    from distdl.utilities.torch import zero_volume_tensor

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return
    P_world = MPIPartition(base_comm)

    # Create the partitions
    P_x_base = P_world.create_partition_inclusive(P_x_ranks)
    P_x = P_x_base.create_cartesian_topology_partition(P_x_shape)

    P_y_base = P_world.create_partition_inclusive(P_y_ranks)
    P_y = P_y_base.create_cartesian_topology_partition(P_y_shape)

    # TODO #93: Change this to create a subtensor so we test when local
    # tensors have different shape.  Then, the output size will also be
    # different, which we will have to get from `y` itself.
    x_local_shape = np.asarray(x_global_shape)

    layer = Broadcast(P_x, P_y, transpose_src=transpose_src, preserve_batch=False)

    x = zero_volume_tensor()
    if P_x.active:
        x = torch.Tensor(np.random.randn(*x_local_shape))
    x.requires_grad = True

    dy = zero_volume_tensor()
    if P_y.active:
        # Adjoint Input
        dy = torch.Tensor(np.random.randn(*x_local_shape))

    # y = F @ x
    y = layer(x)

    # dx = F* @ dy
    y.backward(dy)
    dx = x.grad

    x = x.detach()
    dx = dx.detach()
    dy = dy.detach()
    y = y.detach()

    check_adjoint_test_tight(P_world, x, dx, y, dy)
def backward(ctx, grad_output):
    r"""Adjoint function of zero-volume corrector wrapper.

    This method interfaces to the adjoint of the Jacobian of the forward
    zero-volume corrector operation.

    Parameters
    ----------
    ctx :
        PyTorch context.
    grad_output : `torch.tensor`
        Input tensor.

    Returns
    -------
    output :
        `grad_output` if `input` was not zero-volume, a zero-volume tensor
        otherwise.

    """

    sh = ctx.sh

    if ctx.zero_volume:
        return zero_volume_tensor(sh[0], device=grad_output.device)
    else:
        return grad_output.clone()
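# For context, a minimal sketch of the wrapper the backward above belongs to,
# paired with a simplified `zero_volume_tensor`.  This only illustrates the
# zero-volume-corrector idea -- the class name, the scalar returned for
# zero-volume inputs, and the exact shapes are assumptions, not DistDL's actual
# implementation.
import torch


def zero_volume_tensor(b=None, device=None):
    # A tensor with no elements; shape (b, 0) if a batch size is preserved.
    if b is None:
        return torch.empty((0,), device=device)
    return torch.empty((b, 0), device=device)


class ZeroVolumeCorrector(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input):
        ctx.sh = input.shape
        ctx.zero_volume = (input.numel() == 0)
        if ctx.zero_volume:
            # Give downstream torch code (e.g., a loss reduction) something
            # concrete to operate on in place of the zero-volume tensor.
            return torch.zeros(1, device=input.device)
        return input.clone()

    @staticmethod
    def backward(ctx, grad_output):
        sh = ctx.sh
        if ctx.zero_volume:
            return zero_volume_tensor(sh[0], device=grad_output.device)
        return grad_output.clone()


# Example: a worker holding a zero-volume activation with batch size 4.
x = zero_volume_tensor(4)
x.requires_grad = True
y = ZeroVolumeCorrector.apply(x)
y.sum().backward()
print(x.grad.shape)  # torch.Size([4, 0])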
def forward(ctx, input, P_send, P_recv, preserve_batch,
            input_tensor_structure, output_tensor_structure, dtype):

    ctx.P_send = P_send
    ctx.P_recv = P_recv
    ctx.preserve_batch = preserve_batch
    ctx.input_tensor_structure = input_tensor_structure
    ctx.output_tensor_structure = output_tensor_structure
    ctx.dtype = dtype

    output_requires_grad = output_tensor_structure[0]
    output_tensor_shape = output_tensor_structure[2]

    # This allows all ranks to use the same exit path, so that we can be
    # sure that all requests have cleared.
    if preserve_batch:
        output = zero_volume_tensor(input.shape[0])
    else:
        output = zero_volume_tensor()

    # return output

    requests = []

    # Send all of the data
    if P_send.active:
        input_numpy = input.detach().numpy()
        req = P_send.comm.Ibcast(input_numpy, root=0)
        requests.append(req)

    if P_recv.active:
        # If I also send, make a copy.
        if P_send == P_recv:
            output = input.clone()
        # If I just receive, receive the broadcast
        else:
            output = np.zeros(output_tensor_shape, dtype=dtype)

            req = P_recv.comm.Ibcast(output, root=0)
            req.Wait()
            output = torch.tensor(output, requires_grad=output_requires_grad)

    MPI.Request.Waitall(requests)

    return output
def backward(ctx, grad_output):

    partition = ctx.partition
    sh = ctx.sh

    if partition.rank == 0:
        return grad_output.clone(), None
    else:
        return zero_volume_tensor(sh[0]), None
def test_excepts_mismatched_nondivisible_tensor(barrier_fence_fixture,
                                                comm_split_fixture):

    import numpy as np
    import torch

    from distdl.backends.mpi.partition import MPIPartition
    from distdl.nn.repartition import Repartition
    from distdl.utilities.slicing import compute_subshape
    from distdl.utilities.torch import zero_volume_tensor

    device = torch.device('cuda' if use_cuda else 'cpu')

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return
    P_world = MPIPartition(base_comm)

    # A tensor with size 1 in a dimension cannot be partitioned in that
    # dimension.  (See last dimension of output and tensor.)
    in_shape = (1, 4, 1, 1)
    out_shape = (1, 1, 1, 2)
    x_global_shape = np.array([1, 16, 5, 1])

    in_size = np.prod(in_shape)
    out_size = np.prod(out_shape)

    # Create the partitions
    P_x_base = P_world.create_partition_inclusive(np.arange(0, in_size))
    P_x = P_x_base.create_cartesian_topology_partition(in_shape)

    P_y_base = P_world.create_partition_inclusive(
        np.arange(P_world.size - out_size, P_world.size))
    P_y = P_y_base.create_cartesian_topology_partition(out_shape)

    with pytest.raises(ValueError) as e_info:  # noqa: F841
        layer = Repartition(P_x, P_y)
        layer = layer.to(device)

        # Forward Input
        x = zero_volume_tensor(device=device)
        if P_x.active:
            x_local_shape = compute_subshape(P_x.shape,
                                             P_x.index,
                                             x_global_shape)
            x = torch.randn(*x_local_shape, device=device)
        x.requires_grad = True

        layer(x)

    P_world.deactivate()
    P_x_base.deactivate()
    P_x.deactivate()
    P_y_base.deactivate()
    P_y.deactivate()
def backward(ctx, grad_output):

    slices = ctx.slices
    buffers = ctx.buffers
    neighbor_ranks = ctx.neighbor_ranks
    P_x = ctx.P_x

    if not P_x.active:
        return zero_volume_tensor(grad_output.shape[0]), None, None, None, None

    if P_x.size == 1:
        return grad_output, None, None, None, None

    grad_output_numpy = grad_output.detach().numpy()

    dim = P_x.dim
    for i in reversed(range(dim)):

        lbs, lgs, rbs, rgs = slices[i]
        lbb, lgb, rbb, rgb = buffers[i]
        lrank, rrank = neighbor_ranks[i]

        if lgb is not None:
            np.copyto(lgb, grad_output_numpy[lgs].ravel())
            grad_output_numpy[lgs] = 0.0
        if rgb is not None:
            np.copyto(rgb, grad_output_numpy[rgs].ravel())
            grad_output_numpy[rgs] = 0.0

        ltag = 0
        rtag = 1

        lrecv_req = P_x.comm.Irecv(lbb, source=lrank, tag=rtag) if lbb is not None else MPI.REQUEST_NULL
        rrecv_req = P_x.comm.Irecv(rbb, source=rrank, tag=ltag) if rbb is not None else MPI.REQUEST_NULL
        lsend_req = P_x.comm.Isend(lgb, dest=lrank, tag=ltag) if lgb is not None else MPI.REQUEST_NULL
        rsend_req = P_x.comm.Isend(rgb, dest=rrank, tag=rtag) if rgb is not None else MPI.REQUEST_NULL

        reqs = [lrecv_req, rrecv_req, lsend_req, rsend_req]
        n_reqs_completed = 0

        while n_reqs_completed < len(reqs):
            status = MPI.Status()
            index = MPI.Request.Waitany(reqs, status)

            if index != MPI.UNDEFINED:
                if index == 0:
                    newshape = grad_output_numpy[lbs].shape
                    grad_output_numpy[lbs] += lbb.reshape(newshape)
                elif index == 1:
                    newshape = grad_output_numpy[rbs].shape
                    grad_output_numpy[rbs] += rbb.reshape(newshape)

            n_reqs_completed += 1

    return grad_output, None, None, None, None
def forward(ctx, input, P_x, slices, buffers, neighbor_ranks):

    ctx.slices = slices
    ctx.buffers = buffers
    ctx.neighbor_ranks = neighbor_ranks
    ctx.P_x = P_x

    if not P_x.active:
        return zero_volume_tensor(input.shape[0])

    ctx.mark_dirty(input)

    if P_x.size == 1:
        return input

    input_numpy = input.detach().numpy()

    dim = P_x.dim
    for i in range(dim):

        lbs, lgs, rbs, rgs = slices[i]
        lbb, lgb, rbb, rgb = buffers[i]
        lrank, rrank = neighbor_ranks[i]

        if lbb is not None:
            np.copyto(lbb, input_numpy[lbs].ravel())
        if rbb is not None:
            np.copyto(rbb, input_numpy[rbs].ravel())

        ltag = 0
        rtag = 1

        lrecv_req = P_x.comm.Irecv(lgb, source=lrank, tag=rtag) if lgb is not None else MPI.REQUEST_NULL
        rrecv_req = P_x.comm.Irecv(rgb, source=rrank, tag=ltag) if rgb is not None else MPI.REQUEST_NULL
        lsend_req = P_x.comm.Isend(lbb, dest=lrank, tag=ltag) if lbb is not None else MPI.REQUEST_NULL
        rsend_req = P_x.comm.Isend(rbb, dest=rrank, tag=rtag) if rbb is not None else MPI.REQUEST_NULL

        reqs = [lrecv_req, rrecv_req, lsend_req, rsend_req]
        n_reqs_completed = 0

        while n_reqs_completed < len(reqs):
            status = MPI.Status()
            index = MPI.Request.Waitany(reqs, status)

            if index != MPI.UNDEFINED:
                if index == 0:
                    newshape = input_numpy[lgs].shape
                    np.copyto(input_numpy[lgs], lgb.reshape(newshape))
                elif index == 1:
                    newshape = input_numpy[rgs].shape
                    np.copyto(input_numpy[rgs], rgb.reshape(newshape))

            n_reqs_completed += 1

    return input
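# A stripped-down 1-D halo exchange in plain mpi4py, to show the Irecv/Isend
# pattern used by the forward and backward above without DistDL's per-dimension
# slice and buffer bookkeeping.  The halo width, array sizes, and the use of
# MPI.PROC_NULL at the ends are arbitrary illustration choices; run under MPI,
# e.g. `mpirun -n 4 python halo_1d.py`.
import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()

halo = 1
n_interior = 8
# Local array with one halo cell on each side; interior filled with rank + 1.
u = np.zeros(n_interior + 2 * halo)
u[halo:-halo] = rank + 1.0

lrank = rank - 1 if rank > 0 else MPI.PROC_NULL
rrank = rank + 1 if rank < size - 1 else MPI.PROC_NULL

recv_left = np.zeros(halo)
recv_right = np.zeros(halo)
send_left = u[halo:2 * halo].copy()       # my leftmost interior cells
send_right = u[-2 * halo:-halo].copy()    # my rightmost interior cells

requests = [comm.Irecv(recv_left, source=lrank, tag=1),
            comm.Irecv(recv_right, source=rrank, tag=0),
            comm.Isend(send_left, dest=lrank, tag=0),
            comm.Isend(send_right, dest=rrank, tag=1)]
MPI.Request.Waitall(requests)

# Fill the halo regions with the neighbors' boundary values.
if lrank != MPI.PROC_NULL:
    u[:halo] = recv_left
if rrank != MPI.PROC_NULL:
    u[-halo:] = recv_right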
def backward(ctx, grad_output):

    P_send = ctx.P_send
    P_recv = ctx.P_recv
    preserve_batch = ctx.preserve_batch
    input_tensor_structure = ctx.input_tensor_structure
    dtype = ctx.dtype

    input_requires_grad = input_tensor_structure[0]
    input_tensor_shape = input_tensor_structure[2]

    # This allows all ranks to use the same exit path, so that we can be
    # sure that all requests have cleared.
    if preserve_batch:
        grad_input = zero_volume_tensor(grad_output.shape[0])
    else:
        grad_input = zero_volume_tensor()

    requests = []

    # If I received the reduction in the forward call, I broadcast my data
    if P_recv.active:
        grad_output_numpy = grad_output.detach().numpy()
        req = P_recv.comm.Ibcast(grad_output_numpy, root=0)
        requests.append(req)

    # If I just receive, receive the broadcast
    if P_send.active:
        # If I both sent and received reduction data, then I copy the "input"
        if P_send == P_recv:
            grad_input = grad_output.clone()
        else:
            grad_input = np.zeros(input_tensor_shape, dtype=dtype)

            req = P_send.comm.Ibcast(grad_input, root=0)
            req.Wait()
            grad_input = torch.tensor(grad_input,
                                      requires_grad=input_requires_grad)

    MPI.Request.Waitall(requests)

    return grad_input, None, None, None, None, None, None
def backward(ctx, grad_output):
    r"""Backward function of distributed all-sum-reduction layer.

    This method implements the adjoint of the Jacobian of the
    all-sum-reduce operation, another all-sum-reduce, using the
    ``MPI_Iallreduce`` function.

    When the current worker is inactive in the ``P_allreduce`` partition,
    it will output a zero-volume tensor.

    Parameters
    ----------
    ctx :
        PyTorch context.
    grad_output : `torch.tensor`
        Input tensor.

    Returns
    -------
    grad_input :
        Output tensor.

    """

    P_allreduce = ctx.P_allreduce
    input_tensor_structure = ctx.input_tensor_structure
    device = ctx.device

    grad_input = zero_volume_tensor(device=device)

    requests = []

    # All-sum-reduce is self-adjoint
    if P_allreduce.active:
        numpy_dtype = torch_to_numpy_dtype_dict[input_tensor_structure.dtype]

        reduced_data = np.zeros(input_tensor_structure.shape, dtype=numpy_dtype)
        grad_output_numpy = grad_output.detach().cpu().numpy()
        req = P_allreduce._comm.Iallreduce(grad_output_numpy, reduced_data, op=MPI.SUM)
        requests.append(req)

    MPI.Request.Waitall(requests)

    # If we had to receive data, we need to tensorify it.
    if P_allreduce.active:
        grad_input = torch.tensor(reduced_data,
                                  requires_grad=input_tensor_structure.requires_grad,
                                  device=device)

    return grad_input, None, None, None
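# "All-sum-reduce is self-adjoint" can be checked in a dense, single-process
# toy form: replicating-and-summing a stacked block vector across n workers is
# multiplication by A = ones(n, n) kron I, which is symmetric, so the adjoint
# of the Jacobian is the same all-sum-reduce.  This is only an illustration of
# that claim, not how the distributed layer is implemented.
import numpy as np

n, m = 4, 3
A = np.kron(np.ones((n, n)), np.eye(m))    # all-sum-reduce as a dense matrix
assert np.array_equal(A, A.T)              # self-adjoint

x = np.random.randn(n * m)                 # stacked local tensors
y = A @ x                                  # every block now holds the global sum
dy = np.random.randn(n * m)
assert np.isclose(y @ dy, x @ (A.T @ dy))  # dot-product (adjoint) test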
def test_simple_conv2d_shape(barrier_fence_fixture,
                             comm_split_fixture,
                             P_x_ranks, P_x_shape,
                             x_global_shape,
                             y_local_shape,
                             padding):

    import numpy as np
    import torch

    from distdl.backends.mpi.partition import MPIPartition
    from distdl.nn.conv_feature import DistributedFeatureConv2d
    from distdl.utilities.slicing import compute_subshape
    from distdl.utilities.torch import zero_volume_tensor

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return
    P_world = MPIPartition(base_comm)

    # Create the partitions
    P_x_base = P_world.create_partition_inclusive(P_x_ranks)
    P_x = P_x_base.create_cartesian_topology_partition(P_x_shape)

    x_global_shape = np.asarray(x_global_shape)

    layer = DistributedFeatureConv2d(P_x,
                                     in_channels=x_global_shape[1],
                                     out_channels=10,
                                     kernel_size=[3, 3],
                                     padding=padding,
                                     bias=False)

    x = zero_volume_tensor(x_global_shape[0])
    if P_x.active:
        x_local_shape = compute_subshape(P_x.shape,
                                         P_x.index,
                                         x_global_shape)
        x = torch.zeros(*x_local_shape)
    x.requires_grad = True

    y = layer(x)

    if P_x.active:
        assert(np.array_equal(np.array(y.shape), np.asarray(y_local_shape)))

    P_world.deactivate()
    P_x_base.deactivate()
    P_x.deactivate()
def test_excepts_mismatched_output_partition_tensor(barrier_fence_fixture,
                                                    comm_split_fixture):

    import numpy as np
    import torch

    from distdl.backends.mpi.partition import MPIPartition
    from distdl.nn.transpose import DistributedTranspose
    from distdl.utilities.slicing import compute_subshape
    from distdl.utilities.torch import zero_volume_tensor

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return
    P_world = MPIPartition(base_comm)

    # Output partition rank must match tensor rank
    in_shape = (4, 1, 1)
    out_shape = (1, 1, 1, 2)
    x_global_shape = np.array([16, 5, 5])

    in_size = np.prod(in_shape)
    out_size = np.prod(out_shape)

    # Create the partitions
    P_x_base = P_world.create_partition_inclusive(np.arange(0, in_size))
    P_x = P_x_base.create_cartesian_topology_partition(in_shape)

    P_y_base = P_world.create_partition_inclusive(
        np.arange(P_world.size - out_size, P_world.size))
    P_y = P_y_base.create_cartesian_topology_partition(out_shape)

    with pytest.raises(ValueError) as e_info:  # noqa: F841
        layer = DistributedTranspose(P_x, P_y)

        # Forward Input
        x = zero_volume_tensor()
        if P_x.active:
            x_local_shape = compute_subshape(P_x.shape,
                                             P_x.index,
                                             x_global_shape)
            x = torch.Tensor(np.random.randn(*x_local_shape))
        x.requires_grad = True

        layer(x)
def test_batch_norm_no_training(barrier_fence_fixture,
                                P_x_ranks, P_x_shape,
                                input_shape,
                                num_features,
                                eps,
                                momentum,
                                affine,
                                track_running_stats,
                                comm_split_fixture):

    from distdl.backends.mpi.partition import MPIPartition

    device = torch.device('cuda' if use_cuda else 'cpu')
    torch.manual_seed(0)

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return
    P_world = MPIPartition(base_comm)

    # Create the partitions
    num_dimensions = len(input_shape)
    P_in_out_base = P_world.create_partition_inclusive([0])
    P_in_out = P_in_out_base.create_cartesian_topology_partition([1] * num_dimensions)
    P_x_base = P_world.create_partition_inclusive(P_x_ranks)
    P_x = P_x_base.create_cartesian_topology_partition(P_x_shape)

    # Create the input
    if P_world.rank == 0:
        input_eval = torch.rand(input_shape, dtype=torch.float32, device=device)
    else:
        input_eval = zero_volume_tensor(device=device)

    # Create the sequential network
    if P_world.rank == 0:
        seq_net = torch.nn.Sequential(
            torch.nn.BatchNorm1d(num_features=num_features,
                                 eps=eps,
                                 momentum=momentum,
                                 affine=affine,
                                 track_running_stats=track_running_stats))
        seq_net = seq_net.to(device)
    else:
        seq_net = None

    # Evaluate sequential network
    if P_world.rank == 0:
        seq_net.eval()
        seq_out = seq_net(input_eval)
        seq_out = seq_out.detach().cpu()

    # Create distributed network
    dist_net = torch.nn.Sequential(
        distdl.nn.Repartition(P_in_out, P_x),
        distdl.nn.DistributedBatchNorm(P_x,
                                       num_features=num_features,
                                       eps=eps,
                                       momentum=momentum,
                                       affine=affine,
                                       track_running_stats=track_running_stats),
        distdl.nn.Repartition(P_x, P_in_out))
    dist_net = dist_net.to(device)

    # Evaluate distributed network
    dist_net.eval()
    dist_out = dist_net(input_eval)
    dist_out = dist_out.detach().cpu()

    # Compare the distributed and sequential networks
    if P_world.rank == 0:
        assert dist_out.shape == seq_out.shape

        # Set the absolute tolerance to ~sqrt(e_mach), or the default.
        # PyTorch got their defaults from NumPy, but NumPy defaults to 64-bit
        # floats, not 32-bit floats as torch does.  Consequently, the default
        # torch atol is actually tighter than one can expect from two fp-equal
        # floating point numbers.  The NumPy default of 1e-8 is closer to
        # sqrt(e_mach) for 64-bit numbers.  So we set the 32-bit tolerance to
        # a little tighter than sqrt(1e-7), 1e-5.
        if seq_out.dtype == torch.float64:
            atol = 1e-8
        elif seq_out.dtype == torch.float32:
            atol = 1e-5
        else:
            # torch default
            atol = 1e-8
        assert torch.allclose(dist_out, seq_out, atol=atol)

    P_world.deactivate()
    P_x_base.deactivate()
    P_x.deactivate()
    P_in_out_base.deactivate()
    P_in_out.deactivate()
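# The tolerance reasoning in the comment above, in concrete numbers: NumPy's
# default atol of 1e-8 is roughly sqrt(eps) for float64, while the analogous
# float32 value is roughly 3.5e-4, so the 1e-5 used here is already on the
# tight side for single precision.  A quick check:
import math
import numpy as np

eps32 = np.finfo(np.float32).eps    # ~1.19e-07
eps64 = np.finfo(np.float64).eps    # ~2.22e-16

print(math.sqrt(eps64))             # ~1.5e-08, close to the 1e-8 default
print(math.sqrt(eps32))             # ~3.5e-04, so atol=1e-5 is tighter than sqrt(eps)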
def forward(ctx, input, P_send, P_recv, preserve_batch,
            input_tensor_structure, output_tensor_structure):
    r"""Forward function of distributed sum-reduction layer.

    This method implements the forward sum-reduction operation using the
    ``MPI_Ireduce`` function.

    Any given worker may participate in two MPI reductions, one on the
    ``P_send`` partition and one on the ``P_recv`` partition.  The
    communication pattern and function selection are chosen to avoid
    potential deadlocks due to overlaps in these partitions.

    When the current worker is active in its ``P_send`` partition, it
    *always* has data that must be reduced.  Therefore it will always send
    data (through a sum-reduce) to the root of that partition.

    If the current worker is active in ``P_recv`` then it is guaranteed to
    be the root worker of ``P_recv`` and there are two potential
    scenarios.

    1. If the ``P_send`` and ``P_recv`` partitions are distinct, the
       current worker will receive reduced tensor data as the root of an
       additional ``MPI_Ireduce``.
    2. If the ``P_send`` and ``P_recv`` partitions are the same, the
       reduction is completed by the *first* ``MPI_Ireduce``, so the
       second is not necessary and would, in fact, cause a deadlock.

    When the current worker is inactive in the ``P_recv`` partition, it
    will output a zero-volume tensor, potentially preserving a non-zero
    batch size.

    Parameters
    ----------
    ctx :
        PyTorch context.
    input : `torch.tensor`
        Input tensor.
    P_send : Partition
        Sending partition current worker is a part of.
    P_recv : Partition
        Receiving partition current worker is a part of.
    preserve_batch : bool
        Indicates if batch size should be preserved for zero-volume outputs.
    input_tensor_structure : tuple
        Tuple containing properties of the input tensor (dimension, shape,
        requires_grad).
    output_tensor_structure : tuple
        Tuple containing properties of the output tensor (dimension, shape,
        requires_grad).

    Returns
    -------
    output :
        Output tensor.

    """

    device = input.device
    ctx.P_send = P_send
    ctx.P_recv = P_recv
    ctx.preserve_batch = preserve_batch
    ctx.input_tensor_structure = input_tensor_structure
    ctx.output_tensor_structure = output_tensor_structure
    ctx.device = device

    # This allows all ranks to use the same exit path, so that we can be
    # sure that all requests have cleared.
    if preserve_batch:
        output = zero_volume_tensor(input.shape[0], device=device)
    else:
        output = zero_volume_tensor(device=device)

    requests = []

    # By design, the roots are always 0 in the cross-communicators.

    # If I receive data (either from a remote worker or just from myself),
    # I need to reduce that data.  If I send and receive to myself, this
    # is OK, as the reduction accounts for the copy, unlike the broadcast
    # below.
    if P_send.active:
        numpy_dtype = torch_to_numpy_dtype_dict[input_tensor_structure.dtype]

        reduced_data_send = np.zeros(input_tensor_structure.shape, dtype=numpy_dtype)
        input_numpy = input.detach().cpu().numpy()
        req = P_send._comm.Ireduce(input_numpy, reduced_data_send, root=0, op=MPI.SUM)
        requests.append(req)

    # If I sent data in the forward, I have to receive it here.
    if P_send != P_recv and P_recv.active:
        numpy_dtype = torch_to_numpy_dtype_dict[output_tensor_structure.dtype]

        reduced_data_recv = np.zeros(output_tensor_structure.shape, dtype=numpy_dtype)
        req = P_recv._comm.Ireduce(MPI.IN_PLACE, reduced_data_recv, root=0, op=MPI.SUM)
        requests.append(req)

    MPI.Request.Waitall(requests)

    # If we had to receive data, we need to tensorify it.
    if P_recv.active:
        if P_send == P_recv:
            output = torch.tensor(reduced_data_send,
                                  requires_grad=output_tensor_structure.requires_grad,
                                  device=device)
        else:
            output = torch.tensor(reduced_data_recv,
                                  requires_grad=output_tensor_structure.requires_grad,
                                  device=device)

    return output
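# A small end-to-end usage sketch of the sum-reduction layer these kernels
# implement, modelled on the test code elsewhere in this file.  The partition
# shapes, the local tensor size, and the assumption of a launch with at least
# four MPI ranks (`mpirun -n 4 ...`) are illustration choices, not requirements
# of the layer.
import numpy as np
import torch
from mpi4py import MPI

from distdl.backends.mpi.partition import MPIPartition
from distdl.nn.sum_reduce import SumReduce
from distdl.utilities.torch import zero_volume_tensor

P_world = MPIPartition(MPI.COMM_WORLD)

# Reduce from a 1 x 4 partition down to a 1 x 1 partition.
P_x_base = P_world.create_partition_inclusive(np.arange(4))
P_x = P_x_base.create_cartesian_topology_partition([1, 4])

P_y_base = P_world.create_partition_inclusive([0])
P_y = P_y_base.create_cartesian_topology_partition([1, 1])

layer = SumReduce(P_x, P_y, preserve_batch=False)

x = zero_volume_tensor()
if P_x.active:
    # Every worker contributes a tensor of ones, so the single worker in P_y
    # should receive a tensor of fours.
    x = torch.ones(3, 5)
    x.requires_grad = True

y = layer(x)
if P_y.active:
    print(y)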
def test_batch_norm_with_training(barrier_fence_fixture,
                                  P_x_ranks, P_x_shape,
                                  input_shape,
                                  num_features,
                                  eps,
                                  momentum,
                                  affine,
                                  track_running_stats,
                                  affine_workers,
                                  comm_split_fixture):

    from distdl.backends.mpi.partition import MPIPartition

    torch.manual_seed(0)
    device = torch.device('cuda' if use_cuda else 'cpu')

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return
    P_world = MPIPartition(base_comm)

    # Create the partitions
    num_dimensions = len(input_shape)
    P_in_out_base = P_world.create_partition_inclusive([0])
    P_in_out = P_in_out_base.create_cartesian_topology_partition([1] * num_dimensions)
    P_x_base = P_world.create_partition_inclusive(P_x_ranks)
    P_x = P_x_base.create_cartesian_topology_partition(P_x_shape)
    P_affine = P_world.create_partition_inclusive(affine_workers)

    # Create the input
    if P_world.rank == 0:
        input_train = torch.rand(input_shape, dtype=torch.float32, device=device)
        input_eval = torch.rand(input_shape, dtype=torch.float32, device=device)
        exp = torch.rand(input_shape, dtype=torch.float32, device=device)
    else:
        input_train = zero_volume_tensor(device=device)
        input_eval = zero_volume_tensor(device=device)
        exp = zero_volume_tensor(device=device)

    # Create the sequential network
    if len(input_shape) == 2:
        seq_layer = torch.nn.BatchNorm1d
    elif len(input_shape) == 3:
        seq_layer = torch.nn.BatchNorm1d
    elif len(input_shape) == 4:
        seq_layer = torch.nn.BatchNorm2d
    elif len(input_shape) == 5:
        seq_layer = torch.nn.BatchNorm3d
    if P_world.rank == 0:
        seq_bn = seq_layer(num_features=num_features,
                           eps=eps,
                           momentum=momentum,
                           affine=affine,
                           track_running_stats=track_running_stats)
        seq_bn = seq_bn.to(device)

    # Train sequential network
    if P_world.rank == 0:
        seq_bn.train()
        seq_out1 = seq_bn(input_train)
        seq_loss = ((seq_out1 - exp)**2).sum()
        seq_loss.backward()
        seq_grads = [p.grad.detach().cpu() for p in seq_bn.parameters()]

        # Do a manual weight update (this is what the optimizer does):
        with torch.no_grad():
            for p in seq_bn.parameters():
                p.copy_(p + 0.1 * p.grad)

    # Evaluate sequential network
    if P_world.rank == 0:
        seq_bn.eval()
        seq_out2 = seq_bn(input_eval)
        seq_out2 = seq_out2.detach().cpu()

    # Create distributed network
    tr1 = distdl.nn.Repartition(P_in_out, P_x)
    tr1 = tr1.to(device)
    dist_bn = distdl.nn.DistributedBatchNorm(P_x,
                                             num_features=num_features,
                                             eps=eps,
                                             momentum=momentum,
                                             affine=affine,
                                             track_running_stats=track_running_stats)
    dist_bn = dist_bn.to(device)
    tr2 = distdl.nn.Repartition(P_x, P_in_out)
    tr2 = tr2.to(device)

    # Only the affine workers should have trainable parameters:
    if P_world.rank in affine_workers:
        assert len(list(dist_bn.parameters())) == 2
    else:
        assert len(list(dist_bn.parameters())) == 0

    # Train distributed network
    dist_bn.train()
    dist_out1 = tr2(dist_bn(tr1(input_train)))
    dist_loss = ((dist_out1 - exp)**2).sum()
    assert dist_loss.requires_grad
    dist_loss.backward()

    # Note: We expect the batch norm gradient to have more dimensions than
    # PyTorch's, but both ultimately have volume equal to num_features.
    # So, reshape them now, and gather them onto rank 0 for comparison.
    if P_world.rank == 0:
        dist_grads = []
    if P_world.rank in affine_workers:
        for p in dist_bn.parameters():
            if affine_workers == [0]:
                parts = [p.grad.detach().cpu()]
            else:
                parts = P_affine._comm.gather(p.grad.detach().cpu(), root=0)
            if P_world.rank == 0:
                grad = torch.cat(parts, 1)
                reshaped = grad.reshape((num_features, ))
                dist_grads.append(reshaped)

    # Do a manual weight update (this is what the optimizer does):
    with torch.no_grad():
        for p in dist_bn.parameters():
            p.copy_(p + 0.1 * p.grad)

    # Evaluate distributed network
    dist_bn.eval()
    dist_out2 = tr2(dist_bn(tr1(input_eval)))
    dist_out2 = dist_out2.detach().cpu()

    # Compare the distributed and sequential networks
    if P_world.rank == 0:
        # Set the absolute tolerance to ~sqrt(e_mach), or the default.
        # PyTorch got their defaults from NumPy, but NumPy defaults to 64-bit
        # floats, not 32-bit floats as torch does.  Consequently, the default
        # torch atol is actually tighter than one can expect from two fp-equal
        # floating point numbers.  The NumPy default of 1e-8 is closer to
        # sqrt(e_mach) for 64-bit numbers.  So we set the 32-bit tolerance to
        # sqrt(1e-7), as our usual choice, 1e-5, is too tight.
        if seq_out1.dtype == torch.float64:
            atol = 1e-8
        elif seq_out1.dtype == torch.float32:
            import math
            atol = math.sqrt(1e-7)
        else:
            # torch default
            atol = 1e-8

        assert dist_out1.shape == seq_out1.shape
        assert torch.allclose(dist_out1, seq_out1, rtol=ERROR_THRESHOLD, atol=atol)
        assert dist_loss.shape == seq_loss.shape
        assert torch.allclose(dist_loss, seq_loss, rtol=ERROR_THRESHOLD, atol=atol)
        for dist_grad, seq_grad in zip(dist_grads, seq_grads):
            assert dist_grad.shape == seq_grad.shape
            assert torch.allclose(dist_grad, seq_grad, rtol=ERROR_THRESHOLD, atol=atol)
        assert dist_out2.shape == seq_out2.shape
        assert torch.allclose(dist_out2, seq_out2, rtol=ERROR_THRESHOLD, atol=atol)

    P_world.deactivate()
    P_x_base.deactivate()
    P_x.deactivate()
    P_in_out_base.deactivate()
    P_in_out.deactivate()
    P_affine.deactivate()
def backward(ctx, grad_output):
    r"""Backward function of distributed sum-reduction layer.

    This method implements the adjoint of the Jacobian of the sum-reduce
    operation, a broadcast, using the ``MPI_Ibcast`` function.

    The roles of the respective send and receive partitions are reversed
    in the adjoint algorithm.  Any worker that was the source of reduced
    data in the forward algorithm will be the destination of broadcast
    data in the adjoint.

    Any given worker may participate in two MPI broadcasts, one on the
    ``P_recv`` partition and one on the ``P_send`` partition.  The
    communication pattern and function selection are chosen to avoid
    potential deadlocks due to overlaps in these partitions.

    When the current worker is active in its ``P_recv`` partition, it
    *always* has data that it must share.  It will only be active in
    ``P_recv`` if it is the root worker of that partition, therefore, it
    will send tensor data as the root of an ``MPI_Ibcast``.

    When the current worker is active in its ``P_send`` partition, there
    are multiple potential scenarios.

    1. If it is *active* in a ``P_recv`` partition and ``P_recv`` is the
       *same* partition as ``P_send``, then the input subtensor can simply
       be cloned for the output.
    2. If it is *active* in a ``P_recv`` partition and ``P_recv`` is a
       *different* partition from ``P_send``, then it will receive tensor
       data from the root of an ``MPI_Ibcast``.
    3. If it is *inactive* in a ``P_recv`` partition, then it will receive
       tensor data from the root of an ``MPI_Ibcast``.

    When the current worker is inactive in the ``P_send`` partition, it
    will output a zero-volume tensor, potentially preserving a non-zero
    batch size.

    Parameters
    ----------
    ctx :
        PyTorch context.
    grad_output : `torch.tensor`
        Input tensor.

    Returns
    -------
    grad_input :
        Output tensor.

    """

    P_send = ctx.P_send
    P_recv = ctx.P_recv
    preserve_batch = ctx.preserve_batch
    input_tensor_structure = ctx.input_tensor_structure
    device = ctx.device

    assert grad_output.device == device

    # This allows all ranks to use the same exit path, so that we can be
    # sure that all requests have cleared.
    if preserve_batch:
        grad_input = zero_volume_tensor(grad_output.shape[0], device=device)
    else:
        grad_input = zero_volume_tensor(device=device)

    requests = []

    # If I received the reduction in the forward call, I broadcast my data
    if P_recv.active:
        grad_output_numpy = grad_output.detach().cpu().numpy()
        req = P_recv._comm.Ibcast(grad_output_numpy, root=0)
        requests.append(req)

    # If I just receive, receive the broadcast
    if P_send.active:
        # If I both sent and received reduction data, then I copy the "input"
        if P_send == P_recv:
            grad_input = grad_output.clone()
        else:
            numpy_dtype = torch_to_numpy_dtype_dict[input_tensor_structure.dtype]

            grad_input = np.zeros(input_tensor_structure.shape, dtype=numpy_dtype)

            req = P_send._comm.Ibcast(grad_input, root=0)
            req.Wait()
            grad_input = torch.tensor(grad_input,
                                      requires_grad=input_tensor_structure.requires_grad,
                                      device=device)

    MPI.Request.Waitall(requests)

    return grad_input, None, None, None, None, None
def test_distributed_loss(barrier_fence_fixture,
                          comm_split_fixture,
                          P_x_ranks, P_x_shape,
                          x_global_shape,
                          SequentialLoss,
                          DistributedLoss):

    import torch

    from distdl.backends.mpi.partition import MPIPartition
    from distdl.nn import DistributedTranspose
    from distdl.utilities.torch import zero_volume_tensor

    device = torch.device('cuda' if use_cuda else 'cpu')

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return
    P_world = MPIPartition(base_comm)

    # Create the partitions
    P_x_base = P_world.create_partition_inclusive(P_x_ranks)
    P_x = P_x_base.create_cartesian_topology_partition(P_x_shape)

    P_0_base = P_x_base.create_partition_inclusive([0])
    P_0 = P_0_base.create_cartesian_topology_partition([1] * len(P_x_shape))

    scatter = DistributedTranspose(P_0, P_x).to(device)
    gather = DistributedTranspose(P_x, P_0).to(device)

    for reduction in DistributedLoss._valid_reductions:

        distributed_criterion = DistributedLoss(P_x, reduction=reduction).to(device)
        sequential_criterion = SequentialLoss(reduction=reduction).to(device)

        with torch.no_grad():
            x_g = zero_volume_tensor(device=device)
            y_g = zero_volume_tensor(device=device)
            if P_0.active:
                x_g = torch.rand(x_global_shape, device=device)
                y_g = torch.rand(x_global_shape, device=device)

            x_l = scatter(x_g)
            y_l = scatter(y_g)

        x_l.requires_grad = True
        distributed_loss = distributed_criterion(x_l, y_l)

        # For "none", no reduction is applied, so we check that it computed
        # the same loss as the sequential code by gathering the loss values
        # to the root rank.
        if reduction == "none":
            distributed_loss = gather(distributed_loss)

        if P_0.active:
            x_g.requires_grad = True
            sequential_loss = sequential_criterion(x_g, y_g)
            assert (torch.allclose(distributed_loss, sequential_loss))

        # For any other reduction, we can compare the loss value *and*
        # backpropagate through the distributed loss to verify that it
        # produces the same output.
        if reduction != "none":
            distributed_loss.backward()
            distributed_dx_g = gather(x_l.grad)

            if P_0.active:
                sequential_loss.backward()
                sequential_dx_g = x_g.grad

                assert (torch.allclose(distributed_dx_g, sequential_dx_g))

    P_world.deactivate()
    P_x_base.deactivate()
    P_x.deactivate()
    P_0_base.deactivate()
    P_0.deactivate()
def test_sum_reduce_dtype(barrier_fence_fixture,
                          comm_split_fixture,
                          dtype,
                          test_backward,
                          P_x_ranks, P_x_shape,
                          P_y_ranks, P_y_shape,
                          x_global_shape,
                          transpose_src):

    import numpy as np
    import torch

    from distdl.backends.mpi.partition import MPIPartition
    from distdl.nn.sum_reduce import SumReduce
    from distdl.utilities.torch import zero_volume_tensor

    device = torch.device('cuda' if use_cuda else 'cpu')

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return
    P_world = MPIPartition(base_comm)

    # Create the partitions
    P_x_base = P_world.create_partition_inclusive(P_x_ranks)
    P_x = P_x_base.create_cartesian_topology_partition(P_x_shape)

    P_y_base = P_world.create_partition_inclusive(P_y_ranks)
    P_y = P_y_base.create_cartesian_topology_partition(P_y_shape)

    # TODO #93: Change this to create a subtensor so we test when local
    # tensors have different shape.  Then, the output size will also be
    # different, which we will have to get from `y` itself.
    x_local_shape = np.asarray(x_global_shape)

    layer = SumReduce(P_x, P_y, transpose_src=transpose_src, preserve_batch=False)
    layer = layer.to(device)

    x = zero_volume_tensor(device=device)
    if P_x.active:
        x = 10 * torch.randn(*x_local_shape, device=device).to(dtype)
    x.requires_grad = test_backward

    # y = F @ x
    y = layer(x)

    # If we are not in the output partition, there is no data to test the
    # type against.
    if P_y.active:
        assert y.dtype == dtype

    if test_backward:
        dy = zero_volume_tensor(device=device)
        if P_y.active:
            # Adjoint Input
            dy = 10 * torch.randn(*x_local_shape, device=device).to(dtype)

        # dx = F* @ dy
        y.backward(dy)
        dx = x.grad

        if P_x.active:
            assert dx.dtype == dtype

    P_world.deactivate()
    P_x_base.deactivate()
    P_x.deactivate()
    P_y_base.deactivate()
    P_y.deactivate()
def backward(ctx, grad_output):
    r"""Adjoint function of distributed transpose layer.

    This method implements the adjoint of the Jacobian of the transpose
    operation using MPI immediate-mode, non-blocking communication.

    The roles of the ``P_x`` and ``P_y`` partitions are reversed, but all
    communication across partitions occurs through the ``P_union``
    partition.

    Data is copied using ``MPI_Irecv`` and ``MPI_Isend``.  As is standard
    procedure, the receives are posted first, allowing them to complete as
    they can.  Then, buffers are packed and sent.  Once all sends have
    been posted, received data is unpacked in the order that the receives
    complete.

    When the current worker is inactive in the ``P_x`` partition, it will
    output a zero-volume tensor, potentially preserving a non-zero batch
    size.

    Parameters
    ----------
    ctx :
        PyTorch context.
    grad_output : `torch.tensor`
        Input tensor.

    Returns
    -------
    output :
        Output tensor.

    """

    P_union = ctx.P_union
    x_global_structure = ctx.x_global_structure
    x_local_structure = ctx.x_local_structure

    P_x = ctx.P_x
    P_x_to_y_overlaps = ctx.P_x_to_y_overlaps
    P_x_to_y_buffers = ctx.P_x_to_y_buffers

    P_y = ctx.P_y
    P_y_to_x_overlaps = ctx.P_y_to_x_overlaps
    P_y_to_x_buffers = ctx.P_y_to_x_buffers

    preserve_batch = ctx.preserve_batch

    input_requires_grad = ctx.input_requires_grad

    device = ctx.device
    assert grad_output.device == device

    requests = []

    # Default everyone to output None
    if preserve_batch:
        grad_input = zero_volume_tensor(grad_output.shape[0],
                                        dtype=x_global_structure.dtype,
                                        device=device)
    else:
        grad_input = zero_volume_tensor(dtype=x_global_structure.dtype,
                                        device=device)

    # Recv my input parts
    recv_count = 0
    if P_x.active:
        for (sl, sh, partner), buff in zip(P_x_to_y_overlaps, P_x_to_y_buffers):
            if buff is not None:
                xfer_buff = buff.get_view(sh)
                req = P_union._comm.Irecv(xfer_buff, source=partner, tag=113)
                requests.append(req)
            else:
                # We add this if there is no recv so that the indices of
                # the requests array match the indices of
                # P_x_to_y_overlaps and P_x_to_y_buffers.
                requests.append(MPI.REQUEST_NULL)
            recv_count += 1

    # Pack and send my input parts
    send_count = 0
    if P_y.active:
        for (sl, sh, partner), buff in zip(P_y_to_x_overlaps, P_y_to_x_buffers):
            if buff is not None:
                xfer_buff = buff.get_view(sh)
                np.copyto(xfer_buff, grad_output.detach()[sl].cpu().numpy())
                req = P_union._comm.Isend(xfer_buff, dest=partner, tag=113)
                requests.append(req)
            else:
                # We add this for symmetry, but don't really need it.
                requests.append(MPI.REQUEST_NULL)
            send_count += 1

    if P_x.active:
        numpy_dtype = torch_to_numpy_dtype_dict[x_global_structure.dtype]
        grad_input = np.zeros(x_local_structure.shape, dtype=numpy_dtype)

    # Handle the self-copy
    if P_y.active and P_x.active:
        # Find the self patch in x_to_y
        for (ysl, ysh, y2xpartner) in P_y_to_x_overlaps:
            if y2xpartner == "self":
                for (xsl, xsh, x2ypartner) in P_x_to_y_overlaps:
                    if x2ypartner == "self":
                        np.copyto(grad_input[xsl],
                                  grad_output.detach()[ysl].cpu().numpy())

                        # There is only one case where this can happen
                        break
                # There is only one case where this can happen
                break

    # Unpack the received data as it arrives
    completed_count = 0
    while (completed_count < len(requests)):
        status = MPI.Status()
        index = MPI.Request.Waitany(requests, status)

        # In MPI, we don't get the index out if the request is an
        # instance of MPI.REQUEST_NULL; instead MPI.UNDEFINED is returned.
        if P_x.active and index < recv_count and index != MPI.UNDEFINED:
            # Unpack my output parts
            sl, sh, partner = P_x_to_y_overlaps[index]
            buff = P_x_to_y_buffers[index]
            if buff is not None:
                xfer_buff = buff.get_view(sh)
                # This would normally be an add into the grad_input tensor
                # but we just created it, so a copy is sufficient.
                np.copyto(grad_input[sl], xfer_buff)

        completed_count += 1

    if P_x.active:
        grad_input = torch.tensor(grad_input,
                                  requires_grad=input_requires_grad,
                                  device=device)

    return grad_input, None, None, None, None, None, None, None, None, None, None, None
# Create the transpose layer
layer = Repartition(P_x, P_y, preserve_batch=False)

# Setup the input tensor.  Any worker in P_x will generate its part of the
# input tensor.  Any worker not in P_x will have a zero-volume tensor.
#
# Input tensor will be (on a 1 x 1 partition):
#   [ [ 1 1 1 1 1 ]
#     [ 1 1 1 1 1 ]
#     [ 1 1 1 1 1 ]
#     [ 1 1 1 1 1 ]
#     [ 1 1 1 1 1 ]
#     [ 1 1 1 1 1 ]
#     [ 1 1 1 1 1 ] ]
x = zero_volume_tensor()
if P_x.active:
    x_local_shape = slicing.compute_subshape(P_x.shape,
                                             P_x.index,
                                             x_global_shape)
    x = np.zeros(x_local_shape) + P_x.rank + 1
    x = torch.from_numpy(x)
    x.requires_grad = True

print(f"rank {P_world.rank}; index {P_x.index}; value {x}")

# Apply the layer.
#
# Output tensor will be (on a 2 x 2 partition):
#   [ [ 1 1 1 | 1 1 ]
#     [ 1 1 1 | 1 1 ]
#     [ 1 1 1 | 1 1 ]
#     [ 1 1 1 | 1 1 ]
#     [ ----------- ]
#     [ 1 1 1 | 1 1 ]
#     [ 1 1 1 | 1 1 ]
#     [ 1 1 1 | 1 1 ] ]
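# The local shapes drawn above come from `compute_subshape`.  The snippet below
# prints the subshape each worker of a 2 x 2 partition would own for the 7 x 5
# global tensor in the example.  The expected values in the comment assume the
# convention suggested by the sketch, where workers with smaller indices
# receive the larger pieces; check against your DistDL version.
import numpy as np
from distdl.utilities.slicing import compute_subshape

P_shape = np.array([2, 2])
x_global_shape = np.array([7, 5])

for index in ([0, 0], [0, 1], [1, 0], [1, 1]):
    print(index, compute_subshape(P_shape, np.array(index), x_global_shape))

# Expected (under the assumed convention):
#   [0, 0] -> [4 3]
#   [0, 1] -> [4 2]
#   [1, 0] -> [3 3]
#   [1, 1] -> [3 2]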
def forward(ctx, input, P_union, x_global_structure,
            x_local_structure, y_local_structure,
            P_x, P_x_to_y_overlaps, P_x_to_y_buffers,
            P_y, P_y_to_x_overlaps, P_y_to_x_buffers,
            preserve_batch):
    r"""Forward function of distributed transpose layer.

    This method implements the forward transpose operation using MPI
    immediate-mode, non-blocking communication.

    Any given worker may send data to multiple workers in ``P_y`` and
    receive data from multiple workers in ``P_x``.  All communication
    across partitions occurs through the ``P_union`` partition.

    Data is copied using ``MPI_Irecv`` and ``MPI_Isend``.  As is standard
    procedure, the receives are posted first, allowing them to complete as
    they can.  Then, buffers are packed and sent.  Once all sends have
    been posted, received data is unpacked in the order that the receives
    complete.

    When the current worker is inactive in the ``P_y`` partition, it will
    output a zero-volume tensor, potentially preserving a non-zero batch
    size.

    Parameters
    ----------
    ctx :
        PyTorch context.
    input : `torch.tensor`
        Input tensor.
    P_union : Partition
        Partition through which all communication occurs.
    x_global_structure :
        Structure of the global input tensor.
    x_local_structure :
        Structure of the local input tensor.
    y_local_structure :
        Structure of the local output tensor.
    P_x : Partition
        Input partition.
    P_x_to_y_overlaps : list
        List of tuples (sl, sh, partner) for each send current worker must
        perform.
    P_x_to_y_buffers : list
        List of pre-allocated send buffers for each send current worker
        must perform.
    P_y : Partition
        Output partition.
    P_y_to_x_overlaps : list
        List of tuples (sl, sh, partner) for each receive current worker
        must perform.
    P_y_to_x_buffers : list
        List of pre-allocated receive buffers for each receive current
        worker must perform.
    preserve_batch : bool
        Indicates if batch size should be preserved for zero-volume outputs.

    Returns
    -------
    output :
        Output tensor.

    """

    ctx.P_union = P_union
    ctx.x_global_structure = x_global_structure
    ctx.x_local_structure = x_local_structure

    ctx.P_x = P_x
    ctx.P_x_to_y_overlaps = P_x_to_y_overlaps
    ctx.P_x_to_y_buffers = P_x_to_y_buffers

    ctx.P_y = P_y
    ctx.P_y_to_x_overlaps = P_y_to_x_overlaps
    ctx.P_y_to_x_buffers = P_y_to_x_buffers

    ctx.preserve_batch = preserve_batch

    device = input.device
    ctx.device = device

    input_requires_grad = False
    # Share the requires-grad status, so that it is preserved across the
    # transpose
    if P_union.active:
        # By design, P_x is always first in the union, so we can just take
        # rank 0's status to send
        if P_x.rank == 0:
            input_requires_grad = input.requires_grad
            P_union._comm.Bcast(np.array([1 if input_requires_grad else 0]),
                                root=0)
        else:
            irg = np.array([0], dtype=int)
            P_union._comm.Bcast(irg, root=0)
            input_requires_grad = bool(irg[0] == 1)

    ctx.input_requires_grad = input_requires_grad

    requests = []

    # Default everyone to output nothing
    if preserve_batch:
        output = zero_volume_tensor(input.shape[0],
                                    dtype=x_global_structure.dtype,
                                    device=device)
    else:
        output = zero_volume_tensor(dtype=x_global_structure.dtype,
                                    device=device)

    # If I am getting data, recv my output parts
    recv_count = 0
    if P_y.active:
        for (sl, sh, partner), buff in zip(P_y_to_x_overlaps, P_y_to_x_buffers):
            if buff is not None:
                xfer_buff = buff.get_view(sh)
                req = P_union._comm.Irecv(xfer_buff, source=partner, tag=111)
                requests.append(req)
            else:
                # We add this if there is no recv so that the indices of
                # the requests array match the indices of
                # P_y_to_x_overlaps and P_y_to_x_buffers.
                requests.append(MPI.REQUEST_NULL)
            recv_count += 1

    # If I have data to share, pack and send my input parts
    send_count = 0
    if P_x.active:
        for (sl, sh, partner), buff in zip(P_x_to_y_overlaps, P_x_to_y_buffers):
            if buff is not None:
                xfer_buff = buff.get_view(sh)
                np.copyto(xfer_buff, input.detach()[sl].cpu().numpy())
                req = P_union._comm.Isend(xfer_buff, dest=partner, tag=111)
                requests.append(req)
            else:
                # We add this for symmetry, but don't really need it.
                requests.append(MPI.REQUEST_NULL)
            send_count += 1

    # We do this after the sends so that they can get started before local
    # allocations.
    if P_y.active:
        numpy_dtype = torch_to_numpy_dtype_dict[x_global_structure.dtype]
        output = np.zeros(y_local_structure.shape, dtype=numpy_dtype)

    # Handle the self-copy
    if P_x.active and P_y.active:
        # Find the self patch in x_to_y
        for (xsl, xsh, x2ypartner) in P_x_to_y_overlaps:
            if x2ypartner == "self":
                for (ysl, ysh, y2xpartner) in P_y_to_x_overlaps:
                    if y2xpartner == "self":
                        np.copyto(output[ysl], input.detach()[xsl].cpu().numpy())

                        # There is only one case where this can happen
                        break
                # There is only one case where this can happen
                break

    # Unpack the received data as it arrives
    completed_count = 0
    while (completed_count < len(requests)):
        status = MPI.Status()
        index = MPI.Request.Waitany(requests, status)

        # In MPI, we don't get the index out if the request is an
        # instance of MPI.REQUEST_NULL; instead MPI.UNDEFINED is returned.
        if P_y.active and index < recv_count and index != MPI.UNDEFINED:
            # Unpack my output parts
            sl, sh, partner = P_y_to_x_overlaps[index]
            buff = P_y_to_x_buffers[index]
            if buff is not None:
                xfer_buff = buff.get_view(sh)
                np.copyto(output[sl], xfer_buff)

        completed_count += 1

    if P_y.active:
        output = torch.tensor(output,
                              requires_grad=input_requires_grad,
                              device=device)

    return output
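# The unpack loops above rely on a detail of MPI_Waitany noted in the comments:
# when the completed entry is MPI.REQUEST_NULL (or the list holds no active
# requests), the call yields MPI.UNDEFINED rather than a usable index, so the
# loop counts completions and skips the unpack branch.  A minimal
# single-process demonstration (no communication involved):
from mpi4py import MPI

requests = [MPI.REQUEST_NULL, MPI.REQUEST_NULL]
index = MPI.Request.Waitany(requests)
print(index == MPI.UNDEFINED)  # True: null requests never produce an index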
def __init__(self, P_x, P_y, P_w,
             in_channels=1, out_channels=1,
             bias=True,
             *args, **kwargs):

    super(DistributedGeneralConvBase, self).__init__()

    # P_x is 1 x P_ci x P_d-1 x ... x P_0
    self.P_x = P_x
    # P_y is 1 x P_co x P_d-1 x ... x P_0
    self.P_y = P_y
    # P_w is P_co x P_ci x P_d-1 x ... x P_0
    self.P_w = P_w

    self.P_union = self._distdl_backend.Partition()
    if not (self.P_x.active or self.P_y.active or self.P_w.active):
        return

    # This guarantees that P_union rank 0 has the kernel size, stride,
    # padding, and dilation factors
    P_union = P_w.create_partition_union(P_x)
    P_union = P_union.create_partition_union(P_y)
    self.P_union = P_union

    P_w_shape = None
    if P_union.rank == 0:
        P_w_shape = np.array(P_w.shape, dtype=int)
    P_w_shape = P_union.broadcast_data(P_w_shape, root=0)

    P_co = P_w_shape[0]
    P_ci = P_w_shape[1]
    P_channels = [P_co, P_ci]

    P_x_new_shape = []
    if self.P_x.active:
        if (np.any(P_x.shape[2:] != P_w_shape[2:])):
            raise ValueError("Spatial components of P_x and P_w must match.")
        if P_w_shape[1] != P_x.shape[1]:
            raise ValueError("Index 1 of P_w must match the input channel partition.")
        P_x_new_shape = list(P_x.shape)
        P_x_new_shape.insert(1, 1)
        # Currently a hack, removing the batch dimension because P_w does
        # not have one.  This is OK because we assume there are no
        # partitions in the batch dimension.
        P_x_new_shape = np.asarray(P_x_new_shape[1:], dtype=int)

    # For the purposes of this layer, we re-cast P_x to have the extra
    # dimension.  This has no impact outside of the layer or on the results.
    self.P_x = self.P_x.create_cartesian_topology_partition(P_x_new_shape)

    P_y_new_shape = []
    if self.P_y.active:
        if (np.any(P_y.shape[2:] != P_w_shape[2:])):
            raise ValueError("Spatial components of P_y and P_w must match.")
        if P_w_shape[0] != P_y.shape[1]:
            raise ValueError("Index 0 of P_w must match the output channel partition.")
        P_y_new_shape = list(P_y.shape)
        P_y_new_shape.insert(2, 1)
        # Currently a hack, removing the batch dimension because P_w does
        # not have one.  This is OK because we assume there are no
        # partitions in the batch dimension.
        P_y_new_shape = np.asarray(P_y_new_shape[1:], dtype=int)

    # For the purposes of this layer, we re-cast P_y to have the extra
    # dimension.  This has no impact outside of the layer or on the results.
    self.P_y = self.P_y.create_cartesian_topology_partition(P_y_new_shape)

    P_spatial = P_w_shape[2:]

    self.serial = False
    if self.P_w.size == 1:
        self.serial = True
        self.conv_layer = self.TorchConvType(*args, **kwargs)
        return

    self.receives_weight = False
    self.stores_weight = False
    self.receives_bias = False
    self.stores_bias = False

    # Determine P_r, initialize weights there
    if self.P_w.active:
        # All of P_w always receives the weight
        self.receives_weight = True

        # This subset is taken to be the origin of the spatial component
        w_root_subset = []
        for i, c in enumerate(range_index(P_w.shape)):
            c = np.asarray(c)
            # Find the P_co x P_ci x 1 x ... x 1 subset to store the weights
            if np.all(c[2:] == 0):
                w_root_subset.append(i)

        self.P_wr_base = self.P_w.create_partition_inclusive(w_root_subset)
        # ones are needed so the broadcast will work
        self.P_wr = self.P_wr_base.create_cartesian_topology_partition(
            [P_co, P_ci] + [1] * len(P_spatial))
        self.stores_weight = self.P_wr.active

        b_subset = []
        for i, c in enumerate(range_index(P_w.shape)):
            c = np.asarray(c)
            # Find the P_co x 1 x P_0 x ... x P_D-1 subset that needs biases
            # in its calculation.  This is everywhere that the input channel
            # index is 0.
            if c[1] == 0:
                b_subset.append(i)

        self.P_b_base = self.P_w.create_partition_inclusive(b_subset)
        self.P_b = self.P_b_base.create_cartesian_topology_partition(
            [P_co] + [1] + list(P_spatial))
        self.receives_bias = self.P_b.active and bias

        # Now find the subset of _that_ which actually stores the
        # learnable parameter.
        b_root_subset = []
        for i, c in enumerate(range_index(P_w.shape)):
            c = np.asarray(c)
            # Find the P_co x 1 x 1 x ... x 1 subset to store the biases
            if np.all(c[1:] == 0):
                b_root_subset.append(i)

        self.P_br_base = self.P_w.create_partition_inclusive(b_root_subset)
        # ones are needed so the broadcast will work
        self.P_br = self.P_br_base.create_cartesian_topology_partition(
            [P_co] + [1] + [1] * len(P_spatial))
        self.stores_bias = self.P_br.active and bias

        # Correct the input arguments based on local properties
        local_kwargs = {}
        local_kwargs.update(kwargs)

        # Do this before checking serial so that the layer works properly
        # in the serial case
        local_channels = compute_subshape(P_channels, P_w.index[0:2],
                                          [out_channels, in_channels])
        local_out_channels, local_in_channels = local_channels
        local_kwargs["in_channels"] = local_in_channels
        local_kwargs["out_channels"] = local_out_channels
        local_kwargs["bias"] = self.receives_bias

        self.conv_layer = self.TorchConvType(*args, **local_kwargs)

        # If we store the weight it is a learnable parameter iff it is
        # learnable by default in the layer, which it is.
        if self.stores_weight:
            self._weight = torch.nn.Parameter(self.conv_layer.weight.detach())
        else:
            self._weight = zero_volume_tensor()
        # This always exists so we can copy the property
        self._weight.requires_grad = self.conv_layer.weight.requires_grad

        # https://discuss.pytorch.org/t/assign-parameters-to-nn-module-and-have-grad-fn-track-it/62677/2
        new_weight = self.conv_layer.weight.detach() * 0
        new_weight.requires_grad = self.conv_layer.weight.requires_grad
        del self.conv_layer.weight
        self.conv_layer.weight = new_weight

        # If we store the bias, it is a learnable parameter iff it is
        # learnable by default in the layer, which is only true if it
        # exists.
        if self.stores_bias:
            self._bias = torch.nn.Parameter(self.conv_layer.bias.detach())
        else:
            self._bias = zero_volume_tensor()
        # This does not always exist, but when it does we can copy the
        # property.
        if self.receives_bias:
            self._bias.requires_grad = self.conv_layer.bias.requires_grad

            # https://discuss.pytorch.org/t/assign-parameters-to-nn-module-and-have-grad-fn-track-it/62677/2
            new_bias = self.conv_layer.bias.detach() * 0
            new_bias.requires_grad = self.conv_layer.bias.requires_grad
            del self.conv_layer.bias
            self.conv_layer.bias = new_bias

    # Now we need to share the kernel structure.  The size of the kernel
    # is always the spatial dimensions.
    self.conv_kernel_size = None
    self.conv_stride = None
    self.conv_padding = None
    self.conv_dilation = None
    if P_union.rank == 0:
        self.conv_kernel_size = np.array(self.conv_layer.kernel_size, dtype=int)
        self.conv_stride = np.array(self.conv_layer.stride, dtype=int)
        self.conv_padding = np.array(self.conv_layer.padding, dtype=int)
        self.conv_dilation = np.array(self.conv_layer.dilation, dtype=int)
    self.conv_kernel_size = P_union.broadcast_data(self.conv_kernel_size, root=0)
    self.conv_stride = P_union.broadcast_data(self.conv_stride, root=0)
    self.conv_padding = P_union.broadcast_data(self.conv_padding, root=0)
    self.conv_dilation = P_union.broadcast_data(self.conv_dilation, root=0)

    # We need the halo shape, and other info, to fully populate the pad,
    # halo exchange, and unpad layers.  For pad and unpad, we defer their
    # construction to the pre-forward hook.
    self.pad_layer = None
    self.unpad_layer = None

    # We need to be able to remove some data from the input to the conv
    # layer.
    self.needed_slices = None

    # For the halo layer we also defer construction, so that we can have
    # the halo shape for the input.  The halo will allocate its own
    # buffers, but it needs this information at construction to be able
    # to do this in the pre-forward hook.
    self.halo_layer = None

    # Variables for tracking input changes and buffer construction
    self._distdl_is_setup = False
    self._input_shape = None
    self._input_requires_grad = None

    if P_w.active:
        self.w_broadcast = Broadcast(self.P_wr, self.P_w, preserve_batch=False)

        if self.receives_bias or self.stores_bias:
            self.b_broadcast = Broadcast(self.P_br, self.P_b, preserve_batch=False)

    self.x_broadcast = Broadcast(self.P_x, self.P_w, preserve_batch=True)
    self.y_sum_reduce = SumReduce(self.P_w, self.P_y, preserve_batch=True)
def forward(ctx, input, P_allreduce,
            input_tensor_structure, output_tensor_structure):
    r"""Forward function of distributed all-sum-reduction layer.

    This method implements the forward all-sum-reduction operation using the
    ``MPI_Iallreduce`` function on the communicator defined by
    ``P_allreduce``.

    When the current worker is inactive in the ``P_allreduce`` partition, it
    will output a zero-volume tensor.

    Parameters
    ----------
    ctx :
        PyTorch context.
    input : `torch.tensor`
        Input tensor.
    P_allreduce : Partition
        Partition within which the reduction happens.
    input_tensor_structure : tuple
        Tuple containing properties of the input tensor (dimension, shape,
        requires_grad).
    output_tensor_structure : tuple
        Tuple containing properties of the output tensor (dimension, shape,
        requires_grad).

    Returns
    -------
    output :
        Output tensor.

    """

    device = input.device
    ctx.P_allreduce = P_allreduce
    ctx.input_tensor_structure = input_tensor_structure
    ctx.output_tensor_structure = output_tensor_structure
    ctx.device = device

    output = zero_volume_tensor(device=device)

    requests = []

    # There is no need to specify a root.
    if P_allreduce.active:
        numpy_dtype = torch_to_numpy_dtype_dict[
            input_tensor_structure.dtype]

        reduced_data = np.zeros(input_tensor_structure.shape,
                                dtype=numpy_dtype)
        input_numpy = input.detach().cpu().numpy()
        req = P_allreduce._comm.Iallreduce(input_numpy, reduced_data,
                                           op=MPI.SUM)
        requests.append(req)

    MPI.Request.Waitall(requests)

    # If we had to receive data, we need to tensorify it.
    if P_allreduce.active:
        output = torch.tensor(
            reduced_data,
            requires_grad=output_tensor_structure.requires_grad,
            device=device)

    return output
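# A minimal sketch, assuming the same context as the forward above, of what a
# matching backward could look like. A sum all-reduce is a self-adjoint linear
# operation, so the incoming gradient is itself sum all-reduced over the same
# communicator (the input and output shapes of an all-reduce coincide). This
# is illustrative, not necessarily the library's exact backward.
def backward(ctx, grad_output):
    P_allreduce = ctx.P_allreduce
    input_tensor_structure = ctx.input_tensor_structure
    device = ctx.device

    grad_input = zero_volume_tensor(device=device)

    if P_allreduce.active:
        numpy_dtype = torch_to_numpy_dtype_dict[
            input_tensor_structure.dtype]

        reduced_grad = np.zeros(input_tensor_structure.shape,
                                dtype=numpy_dtype)
        grad_output_numpy = grad_output.detach().cpu().numpy()
        req = P_allreduce._comm.Iallreduce(grad_output_numpy, reduced_grad,
                                           op=MPI.SUM)
        req.Wait()

        grad_input = torch.tensor(
            reduced_grad,
            requires_grad=input_tensor_structure.requires_grad,
            device=device)

    # One gradient per forward argument after ctx.
    return grad_input, None, None, None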
def forward(ctx, input, P_union, x_global_shape,
            P_x, in_data, in_buffers,
            P_y, out_data, out_buffers,
            preserve_batch, dtype):

    ctx.P_union = P_union
    ctx.x_global_shape = x_global_shape

    ctx.P_x = P_x
    ctx.in_data = in_data
    ctx.in_buffers = in_buffers

    ctx.P_y = P_y
    ctx.out_data = out_data
    ctx.out_buffers = out_buffers

    ctx.preserve_batch = preserve_batch
    ctx.dtype = dtype

    input_requires_grad = False
    # By design, P_x is always first in the union.
    if P_union.active:
        if P_x.rank == 0:
            input_requires_grad = input.requires_grad
            P_union.comm.Bcast(np.array([1 if input_requires_grad else 0]),
                               root=0)
        else:
            irg = np.array([0], dtype=int)
            P_union.comm.Bcast(irg, root=0)
            input_requires_grad = bool(irg[0] == 1)

    ctx.input_requires_grad = input_requires_grad

    requests = []

    # Default everyone to output nothing.
    if preserve_batch:
        output = zero_volume_tensor(input.shape[0])
    else:
        output = zero_volume_tensor()

    # If I am getting data, recv my output parts.
    recv_count = 0
    if P_y.active:
        for (sl, sz, partner), buff in zip(out_data, out_buffers):
            if buff is not None:
                req = P_union.comm.Irecv(buff, source=partner, tag=111)
                requests.append(req)
            else:
                # We add this if there is no recv so that the indices of
                # the requests array match the indices of out_data and
                # out_buffers.
                requests.append(MPI.REQUEST_NULL)
            recv_count += 1

    # If I have data to share, pack and send my input parts.
    send_count = 0
    if P_x.active:
        input_numpy = input.detach().numpy()
        for (sl, sz, partner), buff in zip(in_data, in_buffers):
            if buff is not None:
                np.copyto(buff, input_numpy[tuple(sl)].ravel())
                req = P_union.comm.Isend(buff, dest=partner, tag=111)
                requests.append(req)
            else:
                # We add this for symmetry, but don't really need it.
                requests.append(MPI.REQUEST_NULL)
            send_count += 1

    # We do this after the sends so that they can get started before local
    # allocations.
    if P_y.active:
        index = P_y.index
        y_local_shape = compute_subshape(P_y.shape, index, x_global_shape)
        # TODO(#25): The dtype should not be fixed, but correcting this is
        # a thing that needs to be resolved globally.
        output = np.zeros(y_local_shape, dtype=dtype)

    # Unpack the received data as it arrives.
    completed_count = 0
    while completed_count < len(requests):
        status = MPI.Status()
        index = MPI.Request.Waitany(requests, status)

        # In MPI, we don't get the index out if the request is an instance
        # of MPI.REQUEST_NULL; instead, MPI.UNDEFINED is returned.
        if P_y.active and index < recv_count and index != MPI.UNDEFINED:
            # Unpack my output parts.
            sl, sz, partner = out_data[index]
            buff = out_buffers[index]
            if buff is not None:
                sh = output[tuple(sl)].shape
                np.copyto(output[tuple(sl)], buff.reshape(sh))

        completed_count += 1

    if P_y.active:
        output = torch.from_numpy(output)
        output.requires_grad = input_requires_grad

    return output
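# A small, self-contained illustration of the request-bookkeeping pattern used
# above (a hypothetical demo, not part of the layer): MPI.REQUEST_NULL entries
# are legal placeholders that keep the request list's indices aligned with the
# data/buffer lists, and MPI.Request.Waitany reports MPI.UNDEFINED once only
# null requests remain. Run under mpi4py with any number of ranks, e.g.
# `mpirun -n 2 python waitany_demo.py`.
import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()
partner = (rank + 1) % size

recv_buff = np.zeros(1, dtype=np.float64)
send_buff = np.full(1, rank, dtype=np.float64)

# Pad the request list with a null request so its indices still line up with
# a fixed-size bookkeeping structure, just as the forward above does.
requests = [comm.Irecv(recv_buff, source=partner, tag=7),
            MPI.REQUEST_NULL,
            comm.Isend(send_buff, dest=partner, tag=7)]

completed_count = 0
while completed_count < len(requests):
    index = MPI.Request.Waitany(requests)
    if index != MPI.UNDEFINED:
        print(f"rank {rank}: request {index} completed")
    completed_count += 1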
def test_conv_versus_pytorch(barrier_fence_fixture,
                             comm_split_fixture,
                             P_x_ranks, P_x_shape,
                             input_dimensions,
                             x_global_shape,
                             kernel_size,
                             padding,
                             stride,
                             dilation,
                             bias):

    import numpy as np
    import torch
    from torch.nn import Conv1d
    from torch.nn import Conv2d
    from torch.nn import Conv3d

    from distdl.backends.mpi.partition import MPIPartition
    from distdl.nn.conv_feature import DistributedFeatureConv1d
    from distdl.nn.conv_feature import DistributedFeatureConv2d
    from distdl.nn.conv_feature import DistributedFeatureConv3d
    from distdl.nn.repartition import Repartition
    from distdl.utilities.torch import zero_volume_tensor

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return
    P_world = MPIPartition(base_comm)
    P_world._comm.Barrier()

    # Create the partitions
    P_0_base = P_world.create_partition_inclusive(np.arange(1))
    P_x_base = P_world.create_partition_inclusive(P_x_ranks)
    P_0 = P_0_base.create_cartesian_topology_partition([1] * len(P_x_shape))
    P_x = P_x_base.create_cartesian_topology_partition(P_x_shape)

    scatter_layer_x = Repartition(P_0, P_x)
    scatter_layer_x = scatter_layer_x.to(device)
    scatter_layer_y = Repartition(P_0, P_x)
    scatter_layer_y = scatter_layer_y.to(device)
    gather_layer_x = Repartition(P_x, P_0)
    gather_layer_x = gather_layer_x.to(device)
    gather_layer_y = Repartition(P_x, P_0)
    gather_layer_y = gather_layer_y.to(device)

    # Create the layers
    if input_dimensions == 1:
        dist_layer_type = DistributedFeatureConv1d
        seq_layer_type = Conv1d
    elif input_dimensions == 2:
        dist_layer_type = DistributedFeatureConv2d
        seq_layer_type = Conv2d
    elif input_dimensions == 3:
        dist_layer_type = DistributedFeatureConv3d
        seq_layer_type = Conv3d
    dist_layer = dist_layer_type(P_x,
                                 in_channels=x_global_shape[1],
                                 out_channels=10,
                                 kernel_size=kernel_size,
                                 padding=padding,
                                 stride=stride,
                                 dilation=dilation,
                                 bias=bias)
    dist_layer = dist_layer.to(device)
    if P_0.active:
        seq_layer = seq_layer_type(in_channels=x_global_shape[1],
                                   out_channels=10,
                                   kernel_size=kernel_size,
                                   padding=padding,
                                   stride=stride,
                                   dilation=dilation,
                                   bias=bias)
        # Set the weights of both layers to be the same.
        seq_layer = seq_layer.to(device)
        weight = torch.rand_like(seq_layer.weight, device=device)
        seq_layer.weight.data = weight
        dist_layer.weight.data = weight
        if bias:
            bias_weight = torch.rand_like(seq_layer.bias, device=device)
            seq_layer.bias.data = bias_weight
            dist_layer.bias.data = bias_weight

    # Forward input
    x_ref = zero_volume_tensor(device=device)
    x_ref.requires_grad = True
    dy_ref = zero_volume_tensor(device=device)

    # Construct the inputs to the forward and backward functions, as well as
    # the outputs of the sequential layer.
    if P_0.active:
        x_ref = torch.randn(*x_global_shape, device=device)
        x_ref.requires_grad = True
        y_ref = seq_layer(x_ref)
        y_global_shape_calc = y_ref.shape

        dy_ref = torch.randn(*y_global_shape_calc, device=device)

        y_ref.backward(dy_ref)
        dx_ref = x_ref.grad

    # Ensure that the scatter is not part of the computation we are testing.
    with torch.no_grad():
        x = scatter_layer_x(x_ref.detach())
        dy = scatter_layer_y(dy_ref.detach())

    x.requires_grad = True
    y = dist_layer(x)
    y.backward(dy)
    dx = x.grad

    # Ensure that the gather is not part of the computation we are testing.
    with torch.no_grad():
        dx_comp = gather_layer_x(dx.detach())
        y_comp = gather_layer_y(y.detach())

    if P_0.active:
        # Set the absolute tolerance to ~sqrt(e_mach), or the default.
        # PyTorch took its defaults from NumPy, but NumPy defaults to 64-bit
        # floats, not 32-bit floats as torch does.
        # Consequently, the default torch atol is tighter than one can
        # reasonably expect when comparing two 32-bit results that should
        # agree to machine precision. The NumPy default of 1e-8 is close to
        # sqrt(e_mach) for 64-bit numbers, so we set the 32-bit tolerance to
        # 1e-5, somewhat tighter than sqrt(eps_32) ~ 3.5e-4.
        if x_ref.dtype == torch.float64:
            atol = 1e-8
        elif x_ref.dtype == torch.float32:
            atol = 1e-5
        else:
            # torch default
            atol = 1e-8

        # Test the result of each entry independently.
        assert torch.allclose(y_ref, y_comp, atol=atol)
        assert torch.allclose(dx_ref, dx_comp, atol=atol)

    P_world.deactivate()
    P_0_base.deactivate()
    P_0.deactivate()
    P_x_base.deactivate()
    P_x.deactivate()
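# Where the tolerances above come from: the absolute tolerance is keyed to the
# square root of machine epsilon for the dtype being compared. A quick check:
import numpy as np

eps32 = np.finfo(np.float32).eps   # ~1.19e-7
eps64 = np.finfo(np.float64).eps   # ~2.22e-16

print(np.sqrt(eps32))  # ~3.5e-4; the test uses the somewhat tighter 1e-5
print(np.sqrt(eps64))  # ~1.5e-8; close to the 1e-8 default used for float64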
def backward(ctx, grad_output):

    P_union = ctx.P_union
    x_global_shape = ctx.x_global_shape

    P_x = ctx.P_x
    in_data = ctx.in_data
    in_buffers = ctx.in_buffers

    P_y = ctx.P_y
    out_data = ctx.out_data
    out_buffers = ctx.out_buffers

    preserve_batch = ctx.preserve_batch
    dtype = ctx.dtype

    input_requires_grad = ctx.input_requires_grad

    requests = []

    # Default everyone to output None.
    if preserve_batch:
        grad_input = zero_volume_tensor(grad_output.shape[0])
    else:
        grad_input = zero_volume_tensor()

    # Recv my input parts.
    recv_count = 0
    if P_x.active:
        for (sl, sz, partner), buff in zip(in_data, in_buffers):
            if buff is not None:
                req = P_union.comm.Irecv(buff, source=partner, tag=113)
                requests.append(req)
            else:
                # We add this if there is no recv so that the indices of
                # the requests array match the indices of in_data and
                # in_buffers.
                requests.append(MPI.REQUEST_NULL)
            recv_count += 1

    # Pack and send my input parts.
    send_count = 0
    if P_y.active:
        grad_output_numpy = grad_output.detach().numpy()
        for (sl, sz, partner), buff in zip(out_data, out_buffers):
            if buff is not None:
                np.copyto(buff, grad_output_numpy[tuple(sl)].ravel())
                req = P_union.comm.Isend(buff, dest=partner, tag=113)
                requests.append(req)
            else:
                # We add this for symmetry, but don't really need it.
                requests.append(MPI.REQUEST_NULL)
            send_count += 1

    if P_x.active:
        index = P_x.index
        x_local_shape = compute_subshape(P_x.shape, index, x_global_shape)
        # TODO(#25): The dtype should not be fixed, but correcting this is
        # a thing that needs to be resolved globally.
        grad_input = np.zeros(x_local_shape, dtype=dtype)

    # Unpack the received data as it arrives.
    completed_count = 0
    while completed_count < len(requests):
        status = MPI.Status()
        index = MPI.Request.Waitany(requests, status)

        # In MPI, we don't get the index out if the request is an instance
        # of MPI.REQUEST_NULL; instead, MPI.UNDEFINED is returned.
        if P_x.active and index < recv_count and index != MPI.UNDEFINED:
            # Unpack my input parts.
            sl, sz, partner = in_data[index]
            buff = in_buffers[index]
            if buff is not None:
                sh = grad_input[tuple(sl)].shape
                # This would normally be an add into the grad_input tensor,
                # but we just created it, so a copy is sufficient.
                np.copyto(grad_input[tuple(sl)], buff.reshape(sh))

        completed_count += 1

    if P_x.active:
        grad_input = torch.from_numpy(grad_input)
        grad_input.requires_grad = input_requires_grad

    return grad_input, None, None, None, None, None, \
        None, None, None, None, None
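# The forward/backward pair above realizes a linear repartition operator and
# its adjoint: the backward reverses the sends and receives over the same
# slices and buffers. A quick, single-process way to sanity-check any such
# forward/backward pair is the adjoint identity <A x, dy> == <x, A^T dy>.
# The helper below is a hypothetical illustration, not part of the library;
# in a distributed setting the inner products would also be summed over ranks.
import torch


def check_adjoint_identity(apply_A, x_shape, rtol=1e-5):
    # Apply the operator, feed a random cotangent through backward, and
    # compare the two inner products that must match for a true adjoint pair.
    x = torch.randn(*x_shape, requires_grad=True)
    y = apply_A(x)
    dy = torch.randn_like(y)
    y.backward(dy)
    dx = x.grad

    lhs = torch.sum(y.detach() * dy)  # <A x, dy>
    rhs = torch.sum(x.detach() * dx)  # <x, A^T dy>
    return torch.isclose(lhs, rhs, rtol=rtol)


# Example with a plain matrix multiply standing in for the repartition.
W = torch.randn(4, 3)
print(check_adjoint_identity(lambda x: x @ W.T, (5, 3)))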