def test_transpose_adjoint(barrier_fence_fixture,
                           comm_split_fixture,
                           P_x_ranks, P_x_shape,
                           P_y_ranks, P_y_shape,
                           x_global_shape):

    import numpy as np
    import torch

    from distdl.backends.mpi.partition import MPIPartition
    from distdl.nn.transpose import DistributedTranspose
    from distdl.utilities.slicing import compute_subshape
    from distdl.utilities.torch import zero_volume_tensor

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return
    P_world = MPIPartition(base_comm)

    # Create the partitions
    P_x_base = P_world.create_partition_inclusive(P_x_ranks)
    P_x = P_x_base.create_cartesian_topology_partition(P_x_shape)

    P_y_base = P_world.create_partition_inclusive(P_y_ranks)
    P_y = P_y_base.create_cartesian_topology_partition(P_y_shape)

    # The global tensor size is the same for x and y
    layer = DistributedTranspose(P_x, P_y, preserve_batch=False)

    # Forward Input
    x = zero_volume_tensor()
    if P_x.active:
        x_local_shape = compute_subshape(P_x.shape,
                                         P_x.index,
                                         x_global_shape)
        x = torch.Tensor(np.random.randn(*x_local_shape))
    x.requires_grad = True

    # Adjoint Input
    dy = zero_volume_tensor()
    if P_y.active:
        y_local_shape = compute_subshape(P_y.shape,
                                         P_y.index,
                                         x_global_shape)
        dy = torch.Tensor(np.random.randn(*y_local_shape))

    # y = F @ x
    y = layer(x)

    # dx = F* @ dy
    y.backward(dy)
    dx = x.grad

    x = x.detach()
    dx = dx.detach()
    dy = dy.detach()
    y = y.detach()

    check_adjoint_test_tight(P_world, x, dx, y, dy)
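# The adjoint tests in this section call `check_adjoint_test_tight`, which is
# not shown here.  As a rough illustration only (an assumption about its
# behavior, not the actual distdl helper), the standard adjoint check compares
# the inner products <y, dy> and <x, dx>, reduced over all workers in P_world,
# to a tight floating-point tolerance.  The sketch below performs the local
# comparison and ignores the MPI reduction the real helper must do.
def _illustrative_adjoint_check(x, dx, y, dy, rtol=1e-6):
    import torch

    # <F x, dy> should equal <x, F* dy> when backward() implements the adjoint.
    ip_forward = torch.sum(y * dy)
    ip_adjoint = torch.sum(x * dx)
    assert torch.isclose(ip_forward, ip_adjoint, rtol=rtol)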
def _compute_halo_shape(self,
                        shape,
                        index,
                        x_global_shape,
                        kernel_size,
                        stride,
                        padding,
                        dilation,
                        require_nonnegative=True):

    x_global_shape = np.asarray(x_global_shape)

    x_local_shape = compute_subshape(shape, index, x_global_shape)
    x_local_start_index = compute_start_index(shape, index, x_global_shape)

    # Formula from the PyTorch docs for max pooling
    y_global_shape = self._compute_out_shape(x_global_shape, kernel_size,
                                             stride, padding, dilation)

    y_local_shape = compute_subshape(shape, index, y_global_shape)
    y_local_start_index = compute_start_index(shape, index, y_global_shape)

    y_local_left_global_index = y_local_start_index
    x_local_left_global_index_needed = self._compute_min_input_range(y_local_left_global_index,
                                                                     kernel_size,
                                                                     stride,
                                                                     padding,
                                                                     dilation)
    # Clamp to the boundary
    x_local_left_global_index_needed = np.maximum(np.zeros_like(x_global_shape),
                                                  x_local_left_global_index_needed)

    y_local_right_global_index = y_local_start_index + y_local_shape - 1
    x_local_right_global_index_needed = self._compute_max_input_range(y_local_right_global_index,
                                                                      kernel_size,
                                                                      stride,
                                                                      padding,
                                                                      dilation)
    # Clamp to the boundary
    x_local_right_global_index_needed = np.minimum(x_global_shape - 1,
                                                   x_local_right_global_index_needed)

    # Compute the actual ghost values
    x_local_left_halo_shape = x_local_start_index - x_local_left_global_index_needed
    x_local_stop_index = x_local_start_index + x_local_shape - 1
    x_local_right_halo_shape = x_local_right_global_index_needed - x_local_stop_index

    # Make sure the halos are always nonnegative, so we get a valid buffer shape
    if require_nonnegative:
        x_local_left_halo_shape = np.maximum(x_local_left_halo_shape, 0)
        x_local_right_halo_shape = np.maximum(x_local_right_halo_shape, 0)

    return np.hstack([x_local_left_halo_shape,
                      x_local_right_halo_shape]).reshape(2, -1).T
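# A minimal sketch of the output-shape rule referenced above.  It mirrors the
# formula in the PyTorch Conv/MaxPool documentation; the actual distdl
# `_compute_out_shape` is not shown in this section, so treat this as an
# illustration rather than its implementation.
def _illustrative_out_shape(x_shape, kernel_size, stride, padding, dilation):
    import numpy as np

    x_shape = np.asarray(x_shape)
    return np.floor((x_shape + 2*np.asarray(padding)
                     - np.asarray(dilation)*(np.asarray(kernel_size) - 1) - 1)
                    / np.asarray(stride) + 1).astype(int)

# For example, a length-10 signal with kernel_size=3, stride=1, padding=0, and
# dilation=1 yields a length-8 output.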
def test_simple_conv2d_adjoint_weight(barrier_fence_fixture, comm_split_fixture, P_x_ranks, P_x_shape, x_global_shape): import numpy as np import torch from distdl.backends.mpi.partition import MPIPartition from distdl.nn.conv_feature import DistributedFeatureConv2d from distdl.utilities.slicing import compute_subshape from distdl.utilities.torch import zero_volume_tensor # Isolate the minimum needed ranks base_comm, active = comm_split_fixture if not active: return P_world = MPIPartition(base_comm) # Create the partitions P_x_base = P_world.create_partition_inclusive(P_x_ranks) P_x = P_x_base.create_cartesian_topology_partition(P_x_shape) x_global_shape = np.asarray(x_global_shape) layer = DistributedFeatureConv2d(P_x, in_channels=x_global_shape[1], out_channels=10, kernel_size=[3, 3], bias=False) x = zero_volume_tensor(x_global_shape[0]) if P_x.active: x_local_shape = compute_subshape(P_x.shape, P_x.index, x_global_shape) x = torch.randn(*x_local_shape) x.requires_grad = True y = layer(x) dy = zero_volume_tensor(x_global_shape[0]) if P_x.active: dy = torch.randn(*y.shape) y.backward(dy) W = zero_volume_tensor() dW = zero_volume_tensor() if P_x.active: W = layer.weight.detach() dW = layer.weight.grad.detach() dy = dy.detach() y = y.detach() check_adjoint_test_tight(P_world, W, dW, y, dy) P_world.deactivate() P_x_base.deactivate() P_x.deactivate()
def test_linear_adjoint_input(barrier_fence_fixture, comm_split_fixture, P_x_ranks, P_x_shape, P_y_ranks, P_y_shape, P_w_ranks, P_w_shape, x_global_shape, y_global_shape): import numpy as np import torch from distdl.backends.mpi.partition import MPIPartition from distdl.nn.linear import DistributedLinear from distdl.utilities.slicing import compute_subshape from distdl.utilities.torch import zero_volume_tensor # Isolate the minimum needed ranks base_comm, active = comm_split_fixture if not active: return P_world = MPIPartition(base_comm) # Create the partitions P_x_base = P_world.create_partition_inclusive(P_x_ranks) P_x = P_x_base.create_cartesian_topology_partition(P_x_shape) P_y_base = P_world.create_partition_inclusive(P_y_ranks) P_y = P_y_base.create_cartesian_topology_partition(P_y_shape) P_w_base = P_world.create_partition_inclusive(P_w_ranks) P_w = P_w_base.create_cartesian_topology_partition(P_w_shape) x_global_shape = np.asarray(x_global_shape) y_global_shape = np.asarray(y_global_shape) layer = DistributedLinear(P_x, P_y, P_w, x_global_shape[1], y_global_shape[1], bias=False) x = zero_volume_tensor(x_global_shape[0]) if P_x.active: x_local_shape = compute_subshape(P_x.shape, P_x.index, x_global_shape) x = torch.Tensor(np.random.randn(*x_local_shape)) x.requires_grad = True y = layer(x) dy = zero_volume_tensor(x_global_shape[0]) if P_y.active: dy = torch.Tensor(np.random.randn(*y.shape)) y.backward(dy) dx = x.grad x = x.detach() dx = dx.detach() dy = dy.detach() y = y.detach() check_adjoint_test_tight(P_world, x, dx, y, dy)
def test_excepts_mismatched_nondivisible_tensor(barrier_fence_fixture, comm_split_fixture): import numpy as np import torch from distdl.backends.mpi.partition import MPIPartition from distdl.nn.repartition import Repartition from distdl.utilities.slicing import compute_subshape from distdl.utilities.torch import zero_volume_tensor device = torch.device('cuda' if use_cuda else 'cpu') # Isolate the minimum needed ranks base_comm, active = comm_split_fixture if not active: return P_world = MPIPartition(base_comm) # A tensor with size 1 in a dimension cannot be partitioned in that # dimension. (See last dimension of output and tensor.) in_shape = (1, 4, 1, 1) out_shape = (1, 1, 1, 2) x_global_shape = np.array([1, 16, 5, 1]) in_size = np.prod(in_shape) out_size = np.prod(out_shape) # Create the partitions P_x_base = P_world.create_partition_inclusive(np.arange(0, in_size)) P_x = P_x_base.create_cartesian_topology_partition(in_shape) P_y_base = P_world.create_partition_inclusive( np.arange(P_world.size - out_size, P_world.size)) P_y = P_y_base.create_cartesian_topology_partition(out_shape) with pytest.raises(ValueError) as e_info: # noqa: F841 layer = Repartition(P_x, P_y) layer = layer.to(device) # Forward Input x = zero_volume_tensor(device=device) if P_x.active: x_local_shape = compute_subshape(P_x.shape, P_x.index, x_global_shape) x = torch.randn(*x_local_shape, device=device) x.requires_grad = True layer(x) P_world.deactivate() P_x_base.deactivate() P_x.deactivate() P_y_base.deactivate() P_y.deactivate()
def test_simple_conv2d_shape(barrier_fence_fixture,
                             comm_split_fixture,
                             P_x_ranks, P_x_shape,
                             x_global_shape,
                             y_local_shape,
                             padding):

    import numpy as np
    import torch

    from distdl.backends.mpi.partition import MPIPartition
    from distdl.nn.conv_feature import DistributedFeatureConv2d
    from distdl.utilities.slicing import compute_subshape
    from distdl.utilities.torch import zero_volume_tensor

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return
    P_world = MPIPartition(base_comm)

    # Create the partitions
    P_x_base = P_world.create_partition_inclusive(P_x_ranks)
    P_x = P_x_base.create_cartesian_topology_partition(P_x_shape)

    x_global_shape = np.asarray(x_global_shape)

    layer = DistributedFeatureConv2d(P_x,
                                     in_channels=x_global_shape[1],
                                     out_channels=10,
                                     kernel_size=[3, 3],
                                     padding=padding,
                                     bias=False)

    x = zero_volume_tensor(x_global_shape[0])
    if P_x.active:
        x_local_shape = compute_subshape(P_x.shape,
                                         P_x.index,
                                         x_global_shape)
        x = torch.zeros(*x_local_shape)
    x.requires_grad = True

    y = layer(x)

    if P_x.active:
        assert np.array_equal(np.array(y.shape), np.asarray(y_local_shape))

    P_world.deactivate()
    P_x_base.deactivate()
    P_x.deactivate()
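# The shape assertion above depends on how `compute_subshape` splits a global
# dimension across workers.  Below is a hypothetical balanced block
# decomposition for a single dimension, given purely for illustration; the
# exact convention (for example, which ranks absorb the remainder) is defined
# by distdl.utilities.slicing and may differ.
def _illustrative_subshape(partition_size, partition_index, global_length):
    quotient, remainder = divmod(global_length, partition_size)

    # Low-index workers absorb one extra element each until the remainder
    # is used up.
    return quotient + 1 if partition_index < remainder else quotient

# e.g. splitting a length-10 dimension over 4 workers gives local lengths
# [3, 3, 2, 2].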
def __init__(self, P_x, P_y, P_w, in_features, out_features, bias=True):

    super(DistributedLinear, self).__init__()

    # P_x ~ 1 X P_fi
    self.P_x = P_x
    # P_y ~ 1 X P_fo
    self.P_y = P_y
    # P_w ~ P_fo X P_fi
    self.P_w = P_w

    # Bias flag
    self.bias = bias

    # Broadcast layer in the x-tensor
    self.x_broadcast = Broadcast(self.P_x, self.P_w, preserve_batch=True)

    # Each worker in P_w computes its own portion of the weight tensor and then
    # stores its own PyTorch Linear layer.  Only the 0th column of the tensor
    # also stores a bias.
    if self.P_w.active:
        local_in_features = compute_subshape(P_w.shape[1], P_w.index[1], in_features)
        local_out_features = compute_subshape(P_w.shape[0], P_w.index[0], out_features)

        # On column 0, use the specified bias, otherwise no bias to
        # prevent double counting
        bias = self.bias if (self.P_w.index[-1] == 0) else False
        self.sublinear = torch.nn.Linear(local_in_features[0],
                                         local_out_features[0],
                                         bias=bias)

    # Sum-reduce layer to get the y-tensor
    self.y_sum_reduce = SumReduce(self.P_w, self.P_y,
                                  transpose_src=True,
                                  preserve_batch=True)
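# A usage sketch for the constructor above, modeled on the adjoint tests in
# this section.  The rank lists, partition shapes, and feature sizes are
# illustrative assumptions for a 4-rank job, not values required by distdl.
# Partition shapes follow the comments: P_x ~ 1 x P_fi, P_y ~ 1 x P_fo,
# P_w ~ P_fo x P_fi.
def example_distributed_linear(P_world, x):
    from distdl.nn.linear import DistributedLinear

    P_x_base = P_world.create_partition_inclusive([0, 1])
    P_x = P_x_base.create_cartesian_topology_partition([1, 2])    # 1 x P_fi

    P_y_base = P_world.create_partition_inclusive([0, 2])
    P_y = P_y_base.create_cartesian_topology_partition([1, 2])    # 1 x P_fo

    P_w_base = P_world.create_partition_inclusive([0, 1, 2, 3])
    P_w = P_w_base.create_cartesian_topology_partition([2, 2])    # P_fo x P_fi

    # x is the local subtensor on workers in P_x and a zero-volume tensor
    # elsewhere, as in the tests in this section.
    layer = DistributedLinear(P_x, P_y, P_w, in_features=16, out_features=8)
    return layer(x)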
def test_excepts_mismatched_output_partition_tensor(barrier_fence_fixture, comm_split_fixture): import numpy as np import torch from distdl.backends.mpi.partition import MPIPartition from distdl.nn.transpose import DistributedTranspose from distdl.utilities.slicing import compute_subshape from distdl.utilities.torch import zero_volume_tensor # Isolate the minimum needed ranks base_comm, active = comm_split_fixture if not active: return P_world = MPIPartition(base_comm) # Output partition rank must match tensor rank in_shape = (4, 1, 1) out_shape = (1, 1, 1, 2) x_global_shape = np.array([16, 5, 5]) in_size = np.prod(in_shape) out_size = np.prod(out_shape) # Create the partitions P_x_base = P_world.create_partition_inclusive(np.arange(0, in_size)) P_x = P_x_base.create_cartesian_topology_partition(in_shape) P_y_base = P_world.create_partition_inclusive(np.arange(P_world.size-out_size, P_world.size)) P_y = P_y_base.create_cartesian_topology_partition(out_shape) with pytest.raises(ValueError) as e_info: # noqa: F841 layer = DistributedTranspose(P_x, P_y) # Forward Input x = zero_volume_tensor() if P_x.active: x_local_shape = compute_subshape(P_x.shape, P_x.index, x_global_shape) x = torch.Tensor(np.random.randn(*x_local_shape)) x.requires_grad = True layer(x)
def test_linear_adjoint_bias(barrier_fence_fixture, comm_split_fixture, P_x_ranks, P_x_shape, P_y_ranks, P_y_shape, P_w_ranks, P_w_shape, x_global_shape, y_global_shape): import numpy as np import torch from distdl.backends.mpi.partition import MPIPartition from distdl.nn.linear import DistributedLinear from distdl.utilities.slicing import compute_subshape from distdl.utilities.torch import zero_volume_tensor # Isolate the minimum needed ranks base_comm, active = comm_split_fixture if not active: return P_world = MPIPartition(base_comm) # Create the partitions P_x_base = P_world.create_partition_inclusive(P_x_ranks) P_x = P_x_base.create_cartesian_topology_partition(P_x_shape) P_y_base = P_world.create_partition_inclusive(P_y_ranks) P_y = P_y_base.create_cartesian_topology_partition(P_y_shape) P_w_base = P_world.create_partition_inclusive(P_w_ranks) P_w = P_w_base.create_cartesian_topology_partition(P_w_shape) x_global_shape = np.asarray(x_global_shape) y_global_shape = np.asarray(y_global_shape) layer = DistributedLinear(P_x, P_y, P_w, x_global_shape[1], y_global_shape[1], bias=True) x = zero_volume_tensor(x_global_shape[0]) if P_x.active: x_local_shape = compute_subshape(P_x.shape, P_x.index, x_global_shape) # For this test, we are only testing to see if the adjoint works # correctly for the bias term. But the adjoint test only works on the # Jacobian of the linear layer. The Jacobian block for b is 0 for x and # W, so killing x makes the forward operator equal to its Jacobian and # we can test to see that adjoint is computed correctly. x = torch.zeros(*x_local_shape) x.requires_grad = True y = layer(x) dy = zero_volume_tensor(x_global_shape[0]) if P_y.active: dy = torch.Tensor(np.random.randn(*y.shape)) y.backward(dy) b = zero_volume_tensor() db = zero_volume_tensor() if P_w.active and P_w.index[-1] == 0: b = layer.sublinear.bias.detach() db = layer.sublinear.bias.grad.detach() dy = dy.detach() y = y.detach() check_adjoint_test_tight(P_world, b, db, y, dy)
def _distdl_module_setup(self, input):
    r"""Transpose module setup function.

    Constructs the necessary buffers and meta information about outbound
    and inbound copies to each worker.

    This function is called every time something changes in the input
    tensor structure.  It should not be called manually.

    Parameters
    ----------
    input :
        Tuple of forward inputs.  See
        `torch.nn.Module.register_forward_pre_hook` for more details.

    """

    self._distdl_is_setup = True
    self._input_tensor_structure.fill_from_tensor(input[0])

    # If we are not an active worker, do nothing.
    if not self.P_union.active:
        return

    self.input_tensor_structure = TensorStructure(input[0])

    self.global_input_tensor_structure = \
        self._distdl_backend.assemble_global_tensor_structure(self.input_tensor_structure,
                                                              self.P_x,
                                                              self.P_union)
    x_global_shape = self.global_input_tensor_structure.shape

    if self.P_y.active:
        self.output_tensor_structure.shape = compute_subshape(self.P_y.shape,
                                                              self.P_y.index,
                                                              x_global_shape)

    tensor_dim = len(x_global_shape)

    if len(self.P_x_shape) != tensor_dim:
        raise ValueError(f"Input partition must have the same dimension "
                         f"({len(self.P_x_shape)}) as the input tensor rank ({tensor_dim}).")

    if len(self.P_y_shape) != tensor_dim:
        raise ValueError(f"Output partition must have the same dimension "
                         f"({len(self.P_y_shape)}) as the input tensor rank ({tensor_dim}).")

    if 1 in x_global_shape[x_global_shape != self.P_y_shape]:
        raise ValueError(f"Input tensor must not be size 1 "
                         f"({x_global_shape}) in a dimension where "
                         f"the output partition is other than 1 ({self.P_y_shape}).")

    # Get the collective input lengths and origins.  This may be load
    # balanced or it may not be.  Therefore, we will always assume it is
    # not load balanced and just build the subshape tensor manually.
    # This output is needed everywhere so it goes to P_union.
    compute_subtensor_shapes_unbalanced = \
        self._distdl_backend.tensor_decomposition.compute_subtensor_shapes_unbalanced
    x_subtensor_shapes = compute_subtensor_shapes_unbalanced(self.input_tensor_structure,
                                                             self.P_x,
                                                             self.P_union)

    # Get the collective output lengths and origins.  This will always be
    # load balanced, so we can infer the subshape tensor from the global
    # tensor shape and the shape of P_y.  At this point, every worker in
    # P_union has both of these pieces of information, so we can build it
    # with no communication.
    y_subtensor_shapes = compute_subtensor_shapes_balanced(self.global_input_tensor_structure,
                                                           self.P_y_shape)

    # Given all subtensor shapes, we can compute the start and stop indices
    # for each partition.
    x_subtensor_start_indices = compute_subtensor_start_indices(x_subtensor_shapes)
    x_subtensor_stop_indices = compute_subtensor_stop_indices(x_subtensor_shapes)
    y_subtensor_start_indices = compute_subtensor_start_indices(y_subtensor_shapes)
    y_subtensor_stop_indices = compute_subtensor_stop_indices(y_subtensor_shapes)

    # We only need to move data to the output partition if we actually
    # have input data.  It is possible to have both input and output data,
    # either input or output data, or neither.  Hence the active guard.
    if self.P_x.active:
        x_slice = tuple([slice(i, i+1) for i in self.P_x.index] + [slice(None)])
        x_start_index = x_subtensor_start_indices[x_slice].squeeze()
        x_stop_index = x_subtensor_stop_indices[x_slice].squeeze()

        # Compute our overlaps for each output subpartition.
        for rank, P_y_index in enumerate(range_index(self.P_y_shape)):

            y_slice = tuple([slice(i, i+1) for i in P_y_index] + [slice(None)])
            y_start_index = y_subtensor_start_indices[y_slice].squeeze()
            y_stop_index = y_subtensor_stop_indices[y_slice].squeeze()

            sl = compute_subtensor_intersection_slice(x_start_index, x_stop_index,
                                                      y_start_index, y_stop_index)

            if sl is not None:
                sh = compute_nd_slice_shape(sl)
                # If it is a self-copy, mark it so we don't have to create
                # a potentially large buffer
                if self.P_y.active and np.all(P_y_index == self.P_y.index):
                    partner = "self"
                # Otherwise, reverse the mapping to get the output
                # partner's rank in the common partition.
                else:
                    partner = np.where(self.P_y_ranks == rank)[0][0]

                self.P_x_to_y_overlaps.append((sl, sh, partner))

            else:
                self.P_x_to_y_overlaps.append((None, None, None))

    # We only need to obtain data from the input partition if we actually
    # have output data.
    if self.P_y.active:
        y_slice = tuple([slice(i, i+1) for i in self.P_y.index] + [slice(None)])
        y_start_index = y_subtensor_start_indices[y_slice].squeeze()
        y_stop_index = y_subtensor_stop_indices[y_slice].squeeze()

        # Compute our overlaps for each input subpartition.
        for rank, P_x_index in enumerate(range_index(self.P_x_shape)):

            x_slice = tuple([slice(i, i+1) for i in P_x_index] + [slice(None)])
            x_start_index = x_subtensor_start_indices[x_slice].squeeze()
            x_stop_index = x_subtensor_stop_indices[x_slice].squeeze()

            sl = compute_subtensor_intersection_slice(y_start_index, y_stop_index,
                                                      x_start_index, x_stop_index)

            if sl is not None:
                sh = compute_nd_slice_shape(sl)
                # If it is a self-copy, mark it so we don't have to create
                # a potentially large buffer
                if self.P_x.active and np.all(P_x_index == self.P_x.index):
                    partner = "self"
                # Otherwise, reverse the mapping to get the input
                # partner's rank in the common partition.
                else:
                    partner = np.where(self.P_x_ranks == rank)[0][0]

                self.P_y_to_x_overlaps.append((sl, sh, partner))

            else:
                self.P_y_to_x_overlaps.append((None, None, None))

    buffs = self.allocate_transpose_buffers(self.buffer_manager,
                                            self.P_x_to_y_overlaps,
                                            self.P_y_to_x_overlaps,
                                            self.global_input_tensor_structure.dtype)
    self.P_x_to_y_buffers = buffs[0]
    self.P_y_to_x_buffers = buffs[1]
def _compute_exchange_info(self, x_global_shape, kernel_size, stride, padding, dilation, partition_active, partition_shape, partition_index): if not partition_active: return None, None, None, None dim = len(partition_shape) x_global_shape = np.atleast_1d(x_global_shape) kernel_size = np.atleast_1d(kernel_size) stride = np.atleast_1d(stride) padding = np.atleast_1d(padding) dilation = np.atleast_1d(dilation) def compute_lpad_length(array): return len(x_global_shape) - len(array) kernel_size = np.pad(kernel_size, pad_width=(compute_lpad_length(kernel_size), 0), mode='constant', constant_values=1) stride = np.pad(stride, pad_width=(compute_lpad_length(stride), 0), mode='constant', constant_values=1) padding = np.pad(padding, pad_width=(compute_lpad_length(padding), 0), mode='constant', constant_values=0) dilation = np.pad(dilation, pad_width=(compute_lpad_length(dilation), 0), mode='constant', constant_values=1) halo_shape = self._compute_halo_shape(partition_shape, partition_index, x_global_shape, kernel_size, stride, padding, dilation) recv_buffer_shape = halo_shape.copy() send_buffer_shape = np.zeros_like(halo_shape) for i in range(dim): lindex = [x - 1 if j == i else x for j, x in enumerate(partition_index)] nhalo = self._compute_halo_shape(partition_shape, lindex, x_global_shape, kernel_size, stride, padding, dilation) # If I have a left neighbor, my left send buffer size is my left # neighbor's right halo size if(lindex[i] > -1): send_buffer_shape[i, 0] = nhalo[i, 1] rindex = [x + 1 if j == i else x for j, x in enumerate(partition_index)] nhalo = self._compute_halo_shape(partition_shape, rindex, x_global_shape, kernel_size, stride, padding, dilation) # If I have a right neighbor, my right send buffer size is my right # neighbor's left halo size if(rindex[i] < partition_shape[i]): send_buffer_shape[i, 1] = nhalo[i, 0] x_local_shape = compute_subshape(partition_shape, partition_index, x_global_shape) halo_shape_with_negatives = self._compute_halo_shape(partition_shape, partition_index, x_global_shape, kernel_size, stride, padding, dilation, require_nonnegative=False) needed_ranges = self._compute_needed_ranges(x_local_shape, halo_shape_with_negatives) halo_shape = halo_shape.astype(int) needed_ranges = needed_ranges.astype(int) return halo_shape, recv_buffer_shape, send_buffer_shape, needed_ranges
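# The constant-value left-padding in `_compute_exchange_info` simply aligns
# per-dimension hyperparameters with the rank of the input tensor.  A
# standalone sketch of that idiom, with illustrative values only:
import numpy as np

example_kernel_size = np.atleast_1d([3, 3])   # given for the 2 spatial dims
example_tensor_rank = 4                       # e.g., (batch, channel, H, W)
example_lpad = example_tensor_rank - len(example_kernel_size)
example_kernel_size = np.pad(example_kernel_size, pad_width=(example_lpad, 0),
                             mode='constant', constant_values=1)
# example_kernel_size is now [1, 1, 3, 3]: size-1 "kernels" in the batch and
# channel dimensions, so the spatial formula can be applied uniformly.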
def forward(ctx, input, P_union, x_global_shape,
            P_x, in_data, in_buffers,
            P_y, out_data, out_buffers,
            preserve_batch, dtype):

    ctx.P_union = P_union
    ctx.x_global_shape = x_global_shape

    ctx.P_x = P_x
    ctx.in_data = in_data
    ctx.in_buffers = in_buffers

    ctx.P_y = P_y
    ctx.out_data = out_data
    ctx.out_buffers = out_buffers

    ctx.preserve_batch = preserve_batch
    ctx.dtype = dtype

    input_requires_grad = False
    # By design, P_x is always first in the union
    if P_union.active:
        if P_x.rank == 0:
            input_requires_grad = input.requires_grad
            P_union.comm.Bcast(np.array([1 if input_requires_grad else 0]),
                               root=0)
        else:
            irg = np.array([0], dtype=int)
            P_union.comm.Bcast(irg, root=0)
            input_requires_grad = bool(irg[0] == 1)

    ctx.input_requires_grad = input_requires_grad

    requests = []

    # Default everyone to output nothing
    if preserve_batch:
        output = zero_volume_tensor(input.shape[0])
    else:
        output = zero_volume_tensor()

    # If I am getting data, recv my output parts
    recv_count = 0
    if P_y.active:
        for (sl, sz, partner), buff in zip(out_data, out_buffers):
            if buff is not None:
                req = P_union.comm.Irecv(buff, source=partner, tag=111)
                requests.append(req)
            else:
                # We add this if there is no recv so that the indices of
                # the requests array match the indices of out_data and
                # out_buffers.
                requests.append(MPI.REQUEST_NULL)
            recv_count += 1

    # If I have data to share, pack and send my input parts
    send_count = 0
    if P_x.active:
        input_numpy = input.detach().numpy()
        for (sl, sz, partner), buff in zip(in_data, in_buffers):
            if buff is not None:
                np.copyto(buff, input_numpy[tuple(sl)].ravel())
                req = P_union.comm.Isend(buff, dest=partner, tag=111)
                requests.append(req)
            else:
                # We add this for symmetry, but don't really need it.
                requests.append(MPI.REQUEST_NULL)
            send_count += 1

    # We do this after the sends so that they can get started before local
    # allocations.
    if P_y.active:
        index = P_y.index
        y_local_shape = compute_subshape(P_y.shape, index, x_global_shape)
        # TODO(#25): The dtype should not be fixed, but correcting this is
        #            a thing that needs to be resolved globally.
        output = np.zeros(y_local_shape, dtype=dtype)

    # Unpack the received data as it arrives
    completed_count = 0
    while completed_count < len(requests):
        status = MPI.Status()
        index = MPI.Request.Waitany(requests, status)

        # In MPI, we don't get the index out if the request is an
        # instance of MPI.REQUEST_NULL; instead, MPI.UNDEFINED is returned.
        if P_y.active and index < recv_count and index != MPI.UNDEFINED:
            # Unpack my output parts
            sl, sz, partner = out_data[index]
            buff = out_buffers[index]
            if buff is not None:
                sh = output[tuple(sl)].shape
                np.copyto(output[tuple(sl)], buff.reshape(sh))

        completed_count += 1

    if P_y.active:
        output = torch.from_numpy(output)
        output.requires_grad = input_requires_grad

    return output
def test_general_conv1d_adjoint_bias(barrier_fence_fixture, comm_split_fixture, P_x_ranks, P_x_shape, P_y_ranks, P_y_shape, P_w_ranks, P_w_shape, x_global_shape): import numpy as np import torch from distdl.backends.mpi.partition import MPIPartition from distdl.nn.conv_general import DistributedGeneralConv1d from distdl.utilities.slicing import compute_subshape from distdl.utilities.torch import zero_volume_tensor # Isolate the minimum needed ranks base_comm, active = comm_split_fixture if not active: return P_world = MPIPartition(base_comm) # Create the partitions P_x_base = P_world.create_partition_inclusive(P_x_ranks) P_x = P_x_base.create_cartesian_topology_partition(P_x_shape) P_y_base = P_world.create_partition_inclusive(P_y_ranks) P_y = P_y_base.create_cartesian_topology_partition(P_y_shape) P_w_base = P_world.create_partition_inclusive(P_w_ranks) P_w = P_w_base.create_cartesian_topology_partition(P_w_shape) x_global_shape = np.asarray(x_global_shape) layer = DistributedGeneralConv1d(P_x, P_y, P_w, in_channels=x_global_shape[1], out_channels=10, kernel_size=[3], bias=True) x = zero_volume_tensor(x_global_shape[0]) if P_x.active: x_local_shape = compute_subshape(P_x.shape, P_x.index, x_global_shape) x = torch.zeros(*x_local_shape) x.requires_grad = True y = layer(x) dy = zero_volume_tensor(x_global_shape[0]) if P_y.active: dy = torch.randn(*y.shape) y.backward(dy) b = zero_volume_tensor() db = zero_volume_tensor() if P_w.active and layer.stores_bias: b = layer.bias.detach() db = layer.bias.grad.detach() dy = dy.detach() y = y.detach() check_adjoint_test_tight(P_world, b, db, y, dy) P_world.deactivate() P_x_base.deactivate() P_x.deactivate() P_y_base.deactivate() P_y.deactivate() P_w_base.deactivate() P_w.deactivate()
def test_repartition_dtype(barrier_fence_fixture, comm_split_fixture, dtype, test_backward, P_x_ranks, P_x_shape, P_y_ranks, P_y_shape, x_global_shape): import torch from distdl.backends.mpi.partition import MPIPartition from distdl.nn.repartition import Repartition from distdl.utilities.slicing import compute_subshape from distdl.utilities.torch import zero_volume_tensor device = torch.device('cuda' if use_cuda else 'cpu') # Isolate the minimum needed ranks base_comm, active = comm_split_fixture if not active: return P_world = MPIPartition(base_comm) # Create the partitions P_x_base = P_world.create_partition_inclusive(P_x_ranks) P_x = P_x_base.create_cartesian_topology_partition(P_x_shape) P_y_base = P_world.create_partition_inclusive(P_y_ranks) P_y = P_y_base.create_cartesian_topology_partition(P_y_shape) # The global tensor size is the same for x and y layer = Repartition(P_x, P_y, preserve_batch=False) layer = layer.to(device) # Forward Input x = zero_volume_tensor(dtype=dtype, device=device) if P_x.active: x_local_shape = compute_subshape(P_x.shape, P_x.index, x_global_shape) x = 10 * torch.randn(*x_local_shape, device=device).to(dtype) x.requires_grad = test_backward # y = F @ x y = layer(x) if P_y.active: assert y.dtype == dtype if test_backward: # Adjoint Input dy = zero_volume_tensor(dtype=dtype, device=device) if P_y.active: y_local_shape = compute_subshape(P_y.shape, P_y.index, x_global_shape) dy = 10 * torch.randn(*y_local_shape, device=device).to(dtype) # dx = F* @ dy y.backward(dy) dx = x.grad if P_x.active: assert dx.dtype == dtype P_world.deactivate() P_x_base.deactivate() P_x.deactivate() P_y_base.deactivate() P_y.deactivate()
def test_halo_exchange_adjoint(barrier_fence_fixture,
                               comm_split_fixture,
                               P_x_ranks, P_x_shape,
                               x_global_shape,
                               dtype,
                               kernel_size, stride, padding, dilation,
                               MockKernelStyle):

    import numpy as np
    import torch
    import torch.nn.functional as F

    from distdl.backends.mpi.partition import MPIPartition
    from distdl.nn.halo_exchange import HaloExchange
    from distdl.utilities.slicing import compute_subshape
    from distdl.utilities.torch import distdl_padding_to_torch_padding
    from distdl.utilities.torch import zero_volume_tensor

    device = torch.device('cuda' if use_cuda else 'cpu')

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return
    P_world = MPIPartition(base_comm)

    P_x_base = P_world.create_partition_inclusive(P_x_ranks)
    P_x = P_x_base.create_cartesian_topology_partition(P_x_shape)

    x_global_shape = np.asarray(x_global_shape)
    kernel_size = np.asarray(kernel_size)
    stride = np.asarray(stride)
    padding = np.asarray(padding)
    dilation = np.asarray(dilation)

    halo_shape = None
    recv_buffer_shape = None
    send_buffer_shape = None
    if P_x.active:
        mockup_layer = MockKernelStyle()
        exchange_info = mockup_layer._compute_exchange_info(x_global_shape,
                                                            kernel_size,
                                                            stride,
                                                            padding,
                                                            dilation,
                                                            P_x.active,
                                                            P_x.shape,
                                                            P_x.index)
        halo_shape = exchange_info[0]
        recv_buffer_shape = exchange_info[1]
        send_buffer_shape = exchange_info[2]

    halo_layer = HaloExchange(P_x, halo_shape, recv_buffer_shape, send_buffer_shape)
    halo_layer = halo_layer.to(device)

    x = zero_volume_tensor(x_global_shape[0], device=device)
    dy = zero_volume_tensor(x_global_shape[0], device=device)
    if P_x.active:
        x_local_shape = compute_subshape(P_x.shape,
                                         P_x.index,
                                         x_global_shape)

        padding = distdl_padding_to_torch_padding(halo_shape)

        x = torch.randn(*x_local_shape, device=device).to(dtype)

        # Pad the input with the halo space.  We are only testing the behavior
        # of the halo exchange, so the input must be padded before we can do
        # anything.
        x = F.pad(x, pad=padding, mode="constant", value=0)

        # dy is also padded, but we want it to start with data inside it.
        dy = torch.randn(*x.shape, device=device).to(dtype)

    x.requires_grad = True

    # Halo Exchange (both fwd and adj) is in-place.  So, we copy the input
    # data and save the original for the adjoint test.  Because it is in-place,
    # the clones themselves are modified.  This also prevents issues with
    # in-place operations on leaf nodes.
    x_clone = x.clone()
    dy_clone = dy.clone()

    # x_clone is modified in place by halo_layer, but we assign y to
    # reference it for clarity.  y and x_clone are the same object.
    y = halo_layer(x_clone)

    # dy_clone is modified in place by halo_layer-adjoint, but we assign dx to
    # reference it for clarity.  dx and dy_clone are the same object.
    # dx is not in the grad field as you might expect because the operation is
    # in-place.
    y.backward(dy_clone)
    dx = dy_clone

    x = x.detach()
    dx = dx.detach()
    dy = dy.detach()
    y = y.detach()

    check_adjoint_test_tight(P_world, x, dx, y, dy)

    P_world.deactivate()
    P_x_base.deactivate()
    P_x.deactivate()
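# The cloning in the test above avoids autograd errors from mutating leaf
# tensors.  A minimal standalone reproduction of the issue, in plain PyTorch
# and unrelated to any distdl API:
import torch

_leaf = torch.zeros(3, requires_grad=True)
try:
    _leaf.add_(1.0)                 # in-place update of a leaf that requires grad
except RuntimeError:
    _work = _leaf.clone()           # the clone is a non-leaf and may be
    _work.add_(1.0)                 # modified in place safely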
def __init__(self, P_x, P_y, P_w, in_channels, out_channels, kernel_size, stride=1, padding=0, padding_mode='zeros', dilation=1, groups=1, bias=True, buffer_manager=None): super(DistributedGeneralConvBase, self).__init__() # P_x is 1 x P_ci x P_d-1 x ... x P_0 self.P_x = P_x # P_y is 1 x P_co x P_d-1 x ... x P_0 self.P_y = P_y # P_w is P_co x P_ci x P_d-1 x ... x P_0 self.P_w = P_w # Back-end specific buffer manager for economic buffer allocation if buffer_manager is None: buffer_manager = self._distdl_backend.BufferManager() elif type(buffer_manager) is not self._distdl_backend.BufferManager: raise ValueError("Buffer manager type does not match backend.") self.buffer_manager = buffer_manager # Even inactive workers need some partition union self.P_union = self._distdl_backend.Partition() if not (self.P_x.active or self.P_y.active or self.P_w.active): return self.in_channels = in_channels self.out_channels = out_channels self.kernel_size = self._expand_parameter(kernel_size) self.stride = self._expand_parameter(stride) self.padding = self._expand_parameter(padding) self.padding_mode = padding_mode self.dilation = self._expand_parameter(dilation) self.groups = groups self.use_bias = bias # This guarantees that P_union rank 0 has the kernel size, stride, # padding, and dilation factors P_union_temp = P_w.create_partition_union(P_x) self.P_union = P_union_temp.create_partition_union(P_y) # Release the temporary resources P_union_temp.deactivate() # Ensure that all workers have the full size and structure of P_w P_w_shape = None if self.P_union.rank == 0: P_w_shape = np.array(P_w.shape, dtype=np.int) P_w_shape = self.P_union.broadcast_data(P_w_shape, root=0) P_co = P_w_shape[0] P_ci = P_w_shape[1] P_channels = [P_co, P_ci] # Ensure that P_x and P_w are correctly aligned. We also produce a # new P_x that is shaped like 1 x P_ci x P_d-1 x ... x P_0, to assist # with broadcasts. P_x_new_shape = [] if self.P_x.active: if(np.any(P_x.shape[2:] != P_w_shape[2:])): raise ValueError("Spatial components of P_x and P_w must match.") if P_w_shape[1] != P_x.shape[1]: raise ValueError("Index 2 of P_w dimension must match input channel partition.") P_x_new_shape = list(P_x.shape) P_x_new_shape.insert(1, 1) # Currently a hack, removing the batch dimension because P_w does # not have one. This is OK because we assume there are no partitions # in the batch dimension. P_x_new_shape = np.asarray(P_x_new_shape[1:], dtype=int) # For the purposes of this layer, we re-cast P_x to have the extra # dimension. This has no impact outside of the layer or on the results. self.P_x = self.P_x.create_cartesian_topology_partition(P_x_new_shape) # Ensure that P_y and P_w are correctly aligned. We also produce a # new P_y that is shaped like 1 x P_ci x P_d-1 x ... x P_0, to assist # with broadcasts. P_y_new_shape = [] if self.P_y.active: if(np.any(P_y.shape[2:] != P_w_shape[2:])): raise ValueError("Spatial components of P_y and P_w must match.") if P_w_shape[0] != P_y.shape[1]: raise ValueError("Index 1 of P_w dimension must match output channel partition.") P_y_new_shape = list(P_y.shape) P_y_new_shape.insert(2, 1) # Currently a hack, removing the batch dimension because P_w does # not have one. This is OK because we assume there are no partitions # in the batch dimension. P_y_new_shape = np.asarray(P_y_new_shape[1:], dtype=int) # For the purposes of this layer, we re-cast P_x to have the extra # dimension. This has no impact outside of the layer or on the results. 
self.P_y = self.P_y.create_cartesian_topology_partition(P_y_new_shape) P_spatial = P_w_shape[2:] self.serial = self.P_w.size == 1 if self.serial: self.conv_layer = self.TorchConvType(in_channels=in_channels, out_channels=out_channels, kernel_size=self.kernel_size, stride=self.stride, padding=self.padding, padding_mode=self.padding_mode, dilation=self.dilation, groups=self.groups, bias=self.use_bias) self.weight = self.conv_layer.weight self.bias = self.conv_layer.bias return # Need to figure out any padding necessary to handle global padding. # This is only on the input tensor. The convolution will not use # any implicit padding, so the work partition does not need it. if self.P_x.active: dims = len(self.P_x.shape) # We will be using global padding to compute local padding, # so expand it to a numpy array global_padding = np.pad(self.padding, pad_width=(dims-len(self.padding), 0), mode='constant', constant_values=0) self.global_padding = global_padding pad_left_right = self.global_padding.reshape((dims, 1)) + np.zeros((dims, 2), dtype=np.int) self.local_padding = self._compute_local_padding(pad_left_right) # Workers can either store the learnable weights and bias, or they # need copies of it. self.receives_weight = False self.stores_weight = False self.receives_bias = False self.stores_bias = False # Determine root partitions, initialize weights there if self.P_w.active: # All of P_w always receives the weight self.receives_weight = True # This subset is taken to be the origin of the spartial component w_root_subset = [] for i, c in enumerate(range_index(P_w.shape)): c = np.asarray(c) # Find the P_co x P_ci x 1 x ... x 1 subset to store the weights if np.all(c[2:] == 0): w_root_subset.append(i) P_wr_base = self.P_w.create_partition_inclusive(w_root_subset) # ones are needed so the broadcast will work self.P_wr = P_wr_base.create_cartesian_topology_partition([P_co, P_ci] + [1]*len(P_spatial)) self.stores_weight = self.P_wr.active # Release temporary resources P_wr_base.deactivate() b_subset = [] for i, c in enumerate(range_index(P_w.shape)): c = np.asarray(c) # Find the P_co x 1 x P_0 x ... x P_D-1 subset that needs # biases in its calculation. This is everywhere that the input # channels is rank 0. if c[1] == 0: b_subset.append(i) P_b_base = self.P_w.create_partition_inclusive(b_subset) self.P_b = P_b_base.create_cartesian_topology_partition([P_co] + [1] + list(P_spatial)) self.receives_bias = self.P_b.active and self.use_bias # Release temporary resources P_b_base.deactivate() # Now find the subset of _that_ which actually stores the # learnable parameter. b_root_subset = [] for i, c in enumerate(range_index(P_w.shape)): c = np.asarray(c) # Find the P_co x 1 x 1 x ... x 1 subset to store the biases if np.all(c[1:] == 0): b_root_subset.append(i) P_br_base = self.P_w.create_partition_inclusive(b_root_subset) # ones are needed so the broadcast will work self.P_br = P_br_base.create_cartesian_topology_partition([P_co] + [1] + [1]*len(P_spatial)) self.stores_bias = self.P_br.active and self.use_bias # Release temporary resources P_br_base.deactivate() # Correct the input arguments based on local properties # This ensures that the in and out channels are correctly shared. 
local_co, local_ci = compute_subshape(P_channels, P_w.index[0:2], [out_channels, in_channels]) self.conv_layer = self.TorchConvType(in_channels=local_ci, out_channels=local_co, kernel_size=self.kernel_size, stride=self.stride, padding=0, padding_mode='zeros', dilation=self.dilation, groups=groups, bias=self.receives_bias) # If we store the weight it is a learnable parameter iff it is # learnable by default in the layer, which it is. if self.stores_weight: self.weight = torch.nn.Parameter(self.conv_layer.weight.detach()) else: self.register_buffer('weight', zero_volume_tensor()) # This always exists so we can copy the property self.weight.requires_grad = self.conv_layer.weight.requires_grad # https://discuss.pytorch.org/t/assign-parameters-to-nn-module-and-have-grad-fn-track-it/62677/2 new_weight = self.conv_layer.weight.detach() * 0 new_weight.requires_grad = self.conv_layer.weight.requires_grad del self.conv_layer.weight self.conv_layer.weight = new_weight # If we store the bias, it is a learnable parameter iff it is # learnable by default in the layer, which is only true if it # exists. if self.stores_bias: self.bias = torch.nn.Parameter(self.conv_layer.bias.detach()) else: if self.use_bias: self.register_buffer('bias', zero_volume_tensor()) else: self.register_buffer('bias', None) # This does not always exist, but when it does we can copy the # property. if self.receives_bias: self.bias.requires_grad = self.conv_layer.bias.requires_grad # https://discuss.pytorch.org/t/assign-parameters-to-nn-module-and-have-grad-fn-track-it/62677/2 new_bias = self.conv_layer.bias.detach() * 0 new_bias.requires_grad = self.conv_layer.bias.requires_grad del self.conv_layer.bias self.conv_layer.bias = new_bias else: # Workers not in P_w don't have a weight or bias. self.register_buffer('weight', zero_volume_tensor()) if self.use_bias: self.register_buffer('bias', zero_volume_tensor()) else: self.register_buffer('bias', None) # Now we need to share the kernel structure. The size of the kernel # is always the spatial dimensions. self.conv_kernel_size = None self.conv_stride = None self.conv_padding = None self.conv_dilation = None # By construction, rank 0 of the union should always have all of this # information, because it will always construct a local conv layer. We # rely on the local conv layer to properly fill out this information # from the defaults. This info is required for all workers on the # input and output partitions because it is needed to construct the # halos. Rank 0 in the union shares it with everyone. if self.P_union.rank == 0: self.conv_kernel_size = np.array(self.conv_layer.kernel_size, dtype=np.int) self.conv_stride = np.array(self.conv_layer.stride, dtype=np.int) self.conv_padding = np.array(self.conv_layer.padding, dtype=np.int) self.conv_dilation = np.array(self.conv_layer.dilation, dtype=np.int) self.conv_kernel_size = self.P_union.broadcast_data(self.conv_kernel_size, root=0) self.conv_stride = self.P_union.broadcast_data(self.conv_stride, root=0) self.conv_padding = self.P_union.broadcast_data(self.conv_padding, root=0) self.conv_dilation = self.P_union.broadcast_data(self.conv_dilation, root=0) # We need to be able to remove some data from the input to the conv # layer but again need to defer. self.needed_slices = None # For the halo layer we also defer construction, so that we can have # the halo shape for the input. The halo will allocate its own # buffers, but it needs this information at construction to be able # to do this in the pre-forward hook. 
    self.halo_layer = None

    # Variables for tracking input changes and buffer construction
    self._distdl_is_setup = False
    self._input_tensor_structure = TensorStructure()

    # Some layers, those that require no information about the input
    # tensor to setup, can be built now.
    if P_w.active:
        self.w_broadcast = Broadcast(self.P_wr, self.P_w, preserve_batch=False)

        if self.receives_bias or self.stores_bias:
            self.b_broadcast = Broadcast(self.P_br, self.P_b, preserve_batch=False)

    self.x_broadcast = Broadcast(self.P_x, self.P_w, preserve_batch=True)
    self.y_sum_reduce = SumReduce(self.P_w, self.P_y, preserve_batch=True)
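# A usage sketch for the general convolution constructed above, modeled on the
# adjoint tests in this section (see test_general_conv1d_adjoint_bias).  The
# rank lists and partition shapes are illustrative assumptions for a 4-rank
# job, chosen to satisfy the alignment checks in the constructor; they are not
# requirements of distdl.
def example_general_conv1d(P_world, x):
    from distdl.nn.conv_general import DistributedGeneralConv1d

    # P_x ~ 1 x P_ci x P_0, P_y ~ 1 x P_co x P_0, P_w ~ P_co x P_ci x P_0
    P_x_base = P_world.create_partition_inclusive([0, 1])
    P_x = P_x_base.create_cartesian_topology_partition([1, 1, 2])

    P_y_base = P_world.create_partition_inclusive([0, 1, 2, 3])
    P_y = P_y_base.create_cartesian_topology_partition([1, 2, 2])

    P_w_base = P_world.create_partition_inclusive([0, 1, 2, 3])
    P_w = P_w_base.create_cartesian_topology_partition([2, 1, 2])

    layer = DistributedGeneralConv1d(P_x, P_y, P_w,
                                     in_channels=3, out_channels=10,
                                     kernel_size=[3], bias=True)
    return layer(x)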
def __init__(self, P_x, P_y, P_w, in_channels, out_channels, kernel_size, stride=1, padding=0, padding_mode='zeros', dilation=1, groups=1, bias=True, *args, **kwargs): super(DistributedChannelConvBase, self).__init__() # P_x is 1 x P_ci x 1 x ... x 1 self.P_x = P_x # P_y is 1 x P_co x 1 x ... x 1 self.P_y = P_y # P_w is P_co x P_ci x 1 x ... x 1 self.P_w = P_w # Even inactive workers need some partition union P_union = self._distdl_backend.Partition() if not (self.P_x.active or self.P_y.active or self.P_w.active): return self.in_channels = in_channels self.out_channels = out_channels self.kernel_size = self._expand_parameter(kernel_size) self.stride = self._expand_parameter(stride) self.padding = self._expand_parameter(padding) self.padding_mode = padding_mode self.dilation = self._expand_parameter(dilation) self.groups = groups self.use_bias = bias # This guarantees that P_union rank 0 has the kernel size, stride, # padding, and dilation factors P_union_temp = P_w.create_partition_union(P_x) P_union = P_union_temp.create_partition_union(P_y) # Ensure that all workers have the full size and structure of P_w P_w_shape = None if P_union.rank == 0: P_w_shape = np.array(P_w.shape, dtype=np.int) P_w_shape = P_union.broadcast_data(P_w_shape, root=0) # Release the temporary resources P_union_temp.deactivate() P_union.deactivate() P_co = P_w_shape[0] P_ci = P_w_shape[1] P_channels = [P_co, P_ci] # Ensure that P_x and P_w are correctly aligned. We also produce a # new P_x that is shaped like 1 x P_ci x 1 x ... x 1, to assist with # broadcasts. P_x_new_shape = [] if self.P_x.active: if (np.any(P_x.shape[2:] != P_w_shape[2:])): raise ValueError( "Spatial components of P_x and P_w must match.") if (np.any(P_x.shape[2:] != np.ones(len(P_x.shape[2:])))): raise ValueError( "Spatial components of P_x must be 1 x ... x 1.") if P_w_shape[1] != P_x.shape[1]: raise ValueError( "Index 2 of P_w dimension must match input channel partition." ) P_x_new_shape = list(P_x.shape) P_x_new_shape.insert(1, 1) # Currently a hack, removing the batch dimension because P_w does # not have one. This is OK because we assume there are no partitions # in the batch dimension. P_x_new_shape = np.asarray(P_x_new_shape[1:], dtype=int) # For the purposes of this layer, we re-cast P_x to have the extra # dimension. This has no impact outside of the layer or on the results. self.P_x = self.P_x.create_cartesian_topology_partition(P_x_new_shape) # Ensure that P_y and P_w are correctly aligned. We also produce a # new P_y that is shaped like P_co x 1 x 1 x ... x 1, to assist with # broadcasts. P_y_new_shape = [] if self.P_y.active: if (np.any(P_y.shape[2:] != P_w_shape[2:])): raise ValueError( "Spatial components of P_y and P_w must match.") if (np.any(P_y.shape[2:] != np.ones(len(P_y.shape[2:])))): raise ValueError( "Spatial components of P_y must be 1 x ... x 1.") if P_w_shape[0] != P_y.shape[1]: raise ValueError( "Index 1 of P_w dimension must match output channel partition." ) P_y_new_shape = list(P_y.shape) P_y_new_shape.insert(2, 1) # Currently a hack, removing the batch dimension because P_w does # not have one. This is OK because we assume there are no partitions # in the batch dimension. P_y_new_shape = np.asarray(P_y_new_shape[1:], dtype=int) # For the purposes of this layer, we re-cast P_x to have the extra # dimension. This has no impact outside of the layer or on the results. 
self.P_y = self.P_y.create_cartesian_topology_partition(P_y_new_shape) self.serial = self.P_w.size == 1 if self.serial: self.conv_layer = self.TorchConvType( in_channels=in_channels, out_channels=out_channels, kernel_size=self.kernel_size, stride=self.stride, padding=self.padding, padding_mode=self.padding_mode, dilation=self.dilation, groups=self.groups, bias=self.use_bias) self.weight = self.conv_layer.weight self.bias = self.conv_layer.bias return # Flag if the global bias is set self.global_bias = bias # Flags if current worker stores (part of) the bias locally. self.stores_bias = False if self.P_w.active: # Let the P_co column store the bias if it is to be used self.stores_bias = self.P_w.index[1] == 0 and self.use_bias # Correct the input arguments based on local properties # This ensures that the in and out channels are correctly shared. local_co, local_ci = compute_subshape(P_channels, P_w.index[0:2], [out_channels, in_channels]) self.conv_layer = self.TorchConvType( in_channels=local_ci, out_channels=local_co, kernel_size=self.kernel_size, stride=self.stride, padding=self.padding, padding_mode=self.padding_mode, dilation=self.dilation, groups=groups, bias=self.stores_bias) # Workers in P_w alias the conv layer to get their weight and perhaps # biases. Every other worker doesn't have a weight or bias. if self.P_w.active: self.weight = self.conv_layer.weight if self.stores_bias: self.bias = self.conv_layer.bias else: if self.use_bias: self.register_buffer('bias', zero_volume_tensor()) else: self.register_buffer('bias', None) else: self.register_buffer('weight', zero_volume_tensor()) if self.use_bias: self.register_buffer('bias', zero_volume_tensor()) else: self.register_buffer('bias', None) # Variables for tracking input changes and buffer construction self._distdl_is_setup = False self._input_tensor_structure = TensorStructure() self.x_broadcast = Broadcast(self.P_x, self.P_w, preserve_batch=True) self.y_sum_reduce = SumReduce(self.P_w, self.P_y, preserve_batch=True)
def test_repartition_identity(barrier_fence_fixture,
                              comm_split_fixture,
                              P_x_ranks, P_x_shape,
                              x_global_shape,
                              balanced):

    import numpy as np
    import torch

    from distdl.backends.mpi.partition import MPIPartition
    from distdl.nn.repartition import Repartition
    from distdl.utilities.slicing import compute_subshape
    from distdl.utilities.torch import zero_volume_tensor

    device = torch.device('cuda' if use_cuda else 'cpu')

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return
    P_world = MPIPartition(base_comm)

    # Create the partitions
    P_x_base = P_world.create_partition_inclusive(P_x_ranks)
    P_x = P_x_base.create_cartesian_topology_partition(P_x_shape)

    P_y = P_x

    # The global tensor size is the same for x and y
    layer = Repartition(P_x, P_y, preserve_batch=False)
    layer = layer.to(device)

    # Forward Input
    x = zero_volume_tensor(device=device)
    if P_x.active:
        if balanced:
            x_local_shape = compute_subshape(P_x.shape,
                                             P_x.index,
                                             x_global_shape)
        else:
            quotient = np.atleast_1d(x_global_shape) // np.atleast_1d(P_x_shape)
            remainder = np.atleast_1d(x_global_shape) % np.atleast_1d(P_x_shape)
            loc = np.where(P_x.index == 0)

            x_local_shape = quotient.copy()
            x_local_shape[loc] += remainder[loc]

        x = torch.randn(*x_local_shape, device=device)

    x.requires_grad = True

    # Adjoint Input
    dy = zero_volume_tensor(device=device)
    if P_y.active:
        y_local_shape = compute_subshape(P_y.shape,
                                         P_y.index,
                                         x_global_shape)
        dy = torch.randn(*y_local_shape, device=device)

    # y = F @ x
    y = layer(x)

    # In the balanced case, this should be a true identity, so there should
    # be no communication performed, just self-copies.
    if balanced:
        for sl, sz, p in layer.P_x_to_y_overlaps:
            assert p == "self" or (sl, sz, p) == (None, None, None)

        for sl, sz, p in layer.P_y_to_x_overlaps:
            assert p == "self" or (sl, sz, p) == (None, None, None)

    # dx = F* @ dy
    y.backward(dy)
    dx = x.grad

    x = x.detach()
    dx = dx.detach()
    dy = dy.detach()
    y = y.detach()

    check_adjoint_test_tight(P_world, x, dx, y, dy)

    P_world.deactivate()
    P_x_base.deactivate()
    P_x.deactivate()
def __init__(self, P_x, P_y, P_w, in_channels=1, out_channels=1, bias=True, *args, **kwargs): super(DistributedGeneralConvBase, self).__init__() # P_x is 1 x P_ci x P_d-1 x ... x P_0 self.P_x = P_x # P_y is 1 x P_co x P_d-1 x ... x P_0 self.P_y = P_y # P_w is P_co x P_ci x P_d-1 x ... x P_0 self.P_w = P_w self.P_union = self._distdl_backend.Partition() if not (self.P_x.active or self.P_y.active or self.P_w.active): return # This guarantees that P_union rank 0 has the kernel size, stride, # padding, and dilation factors P_union = P_w.create_partition_union(P_x) P_union = P_union.create_partition_union(P_y) self.P_union = P_union P_w_shape = None if P_union.rank == 0: P_w_shape = np.array(P_w.shape, dtype=np.int) P_w_shape = P_union.broadcast_data(P_w_shape, root=0) P_co = P_w_shape[0] P_ci = P_w_shape[1] P_channels = [P_co, P_ci] P_x_new_shape = [] if self.P_x.active: if (np.any(P_x.shape[2:] != P_w_shape[2:])): raise ValueError( "Spatial components of P_x and P_w must match.") if P_w_shape[1] != P_x.shape[1]: raise ValueError( "Index 2 of P_w dimension must match input channel partition." ) P_x_new_shape = list(P_x.shape) P_x_new_shape.insert(1, 1) # Currently a hack, removing the batch dimension because P_w does # not have one. This is OK because we assume there are no partitions # in the batch dimension. P_x_new_shape = np.asarray(P_x_new_shape[1:], dtype=int) # For the purposes of this layer, we re-cast P_x to have the extra # dimension. This has no impact outside of the layer or on the results. self.P_x = self.P_x.create_cartesian_topology_partition(P_x_new_shape) P_y_new_shape = [] if self.P_y.active: if (np.any(P_y.shape[2:] != P_w_shape[2:])): raise ValueError( "Spatial components of P_y and P_w must match.") if P_w_shape[0] != P_y.shape[1]: raise ValueError( "Index 1 of P_w dimension must match output channel partition." ) P_y_new_shape = list(P_y.shape) P_y_new_shape.insert(2, 1) # Currently a hack, removing the batch dimension because P_w does # not have one. This is OK because we assume there are no partitions # in the batch dimension. P_y_new_shape = np.asarray(P_y_new_shape[1:], dtype=int) # For the purposes of this layer, we re-cast P_x to have the extra # dimension. This has no impact outside of the layer or on the results. self.P_y = self.P_y.create_cartesian_topology_partition(P_y_new_shape) P_spatial = P_w_shape[2:] self.serial = False if self.P_w.size == 1: self.serial = True self.conv_layer = self.TorchConvType(*args, **kwargs) return self.receives_weight = False self.stores_weight = False self.receives_bias = False self.stores_bias = False # Determine P_r, initialize weights there if self.P_w.active: # All of P_w always receives the weight self.receives_weight = True # This subset is taken to be the origin of the spartial component w_root_subset = [] for i, c in enumerate(range_index(P_w.shape)): c = np.asarray(c) # Find the P_co x P_ci x 1 x ... x 1 subset to store the weights if np.all(c[2:] == 0): w_root_subset.append(i) self.P_wr_base = self.P_w.create_partition_inclusive(w_root_subset) # ones are needed so the broadcast will work self.P_wr = self.P_wr_base.create_cartesian_topology_partition( [P_co, P_ci] + [1] * len(P_spatial)) self.stores_weight = self.P_wr.active b_subset = [] for i, c in enumerate(range_index(P_w.shape)): c = np.asarray(c) # Find the P_co x 1 x P_0 x ... x P_D-1 subset that needs biases in its calculation. # This is everywhere that the input channels is rank 0. 
if c[1] == 0: b_subset.append(i) self.P_b_base = self.P_w.create_partition_inclusive(b_subset) self.P_b = self.P_b_base.create_cartesian_topology_partition( [P_co] + [1] + list(P_spatial)) self.receives_bias = self.P_b.active and bias # Now find the subset of _that_ which actually stores the learnable parameter. b_root_subset = [] for i, c in enumerate(range_index(P_w.shape)): c = np.asarray(c) # Find the P_co x 1 x 1 x ... x 1 subset to store the biases if np.all(c[1:] == 0): b_root_subset.append(i) self.P_br_base = self.P_w.create_partition_inclusive(b_root_subset) # ones are needed so the broadcast will work self.P_br = self.P_br_base.create_cartesian_topology_partition( [P_co] + [1] + [1] * len(P_spatial)) self.stores_bias = self.P_br.active and bias # Correct the input arguments based on local properties local_kwargs = {} local_kwargs.update(kwargs) # Do this before checking serial so that the layer works properly # in the serial case local_channels = compute_subshape(P_channels, P_w.index[0:2], [out_channels, in_channels]) local_out_channels, local_in_channels = local_channels local_kwargs["in_channels"] = local_in_channels local_kwargs["out_channels"] = local_out_channels local_kwargs["bias"] = self.receives_bias self.conv_layer = self.TorchConvType(*args, **local_kwargs) # If we store the weight it is a learnable parameter iff it is # learnable by default in the layer, which it is. if self.stores_weight: self._weight = torch.nn.Parameter( self.conv_layer.weight.detach()) else: self._weight = zero_volume_tensor() # This always exists so we can copy the property self._weight.requires_grad = self.conv_layer.weight.requires_grad # https://discuss.pytorch.org/t/assign-parameters-to-nn-module-and-have-grad-fn-track-it/62677/2 new_weight = self.conv_layer.weight.detach() * 0 new_weight.requires_grad = self.conv_layer.weight.requires_grad del self.conv_layer.weight self.conv_layer.weight = new_weight # If we store the bias, it is a learnable parameter iff it is # learnable by default in the layer, which is only true if it # exists. if self.stores_bias: self._bias = torch.nn.Parameter(self.conv_layer.bias.detach()) else: self._bias = zero_volume_tensor() # This does not always exist, but when it does we can copy the # property. if self.receives_bias: self._bias.requires_grad = self.conv_layer.bias.requires_grad # https://discuss.pytorch.org/t/assign-parameters-to-nn-module-and-have-grad-fn-track-it/62677/2 new_bias = self.conv_layer.bias.detach() * 0 new_bias.requires_grad = self.conv_layer.bias.requires_grad del self.conv_layer.bias self.conv_layer.bias = new_bias # Now we need to share the kernel structure. The size of the kernel # is always the spatial dimensions. self.conv_kernel_size = None self.conv_stride = None self.conv_padding = None self.conv_dilation = None if P_union.rank == 0: self.conv_kernel_size = np.array(self.conv_layer.kernel_size, dtype=np.int) self.conv_stride = np.array(self.conv_layer.stride, dtype=np.int) self.conv_padding = np.array(self.conv_layer.padding, dtype=np.int) self.conv_dilation = np.array(self.conv_layer.dilation, dtype=np.int) self.conv_kernel_size = P_union.broadcast_data(self.conv_kernel_size, root=0) self.conv_stride = P_union.broadcast_data(self.conv_stride, root=0) self.conv_padding = P_union.broadcast_data(self.conv_padding, root=0) self.conv_dilation = P_union.broadcast_data(self.conv_dilation, root=0) # We need the halo shape, and other info, to fully populate the pad, # halo exchange, and unpad layers. 
    # For pad and unpad, we defer their construction to the pre-forward hook.
    self.pad_layer = None
    self.unpad_layer = None

    # We need to be able to remove some data from the input to the conv
    # layer.
    self.needed_slices = None

    # For the halo layer we also defer construction, so that we can have
    # the halo shape for the input.  The halo will allocate its own
    # buffers, but it needs this information at construction to be able
    # to do this in the pre-forward hook.
    self.halo_layer = None

    # Variables for tracking input changes and buffer construction
    self._distdl_is_setup = False
    self._input_shape = None
    self._input_requires_grad = None

    if P_w.active:
        self.w_broadcast = Broadcast(self.P_wr, self.P_w,
                                     preserve_batch=False)

        if self.receives_bias or self.stores_bias:
            self.b_broadcast = Broadcast(self.P_br, self.P_b,
                                         preserve_batch=False)

    self.x_broadcast = Broadcast(self.P_x, self.P_w,
                                 preserve_batch=True)
    self.y_sum_reduce = SumReduce(self.P_w, self.P_y,
                                  preserve_batch=True)
layer = Repartition(P_x, P_y, preserve_batch=False)

# Setup the input tensor.  Any worker in P_x will generate its part of the
# input tensor.  Any worker not in P_x will have a zero-volume tensor.
#
# Input tensor will be (on a 1 x 1 partition):
# [ [ 1 1 1 1 1 ]
#   [ 1 1 1 1 1 ]
#   [ 1 1 1 1 1 ]
#   [ 1 1 1 1 1 ]
#   [ 1 1 1 1 1 ]
#   [ 1 1 1 1 1 ]
#   [ 1 1 1 1 1 ] ]
x = zero_volume_tensor()
if P_x.active:
    x_local_shape = slicing.compute_subshape(P_x.shape,
                                             P_x.index,
                                             x_global_shape)
    x = np.zeros(x_local_shape) + P_x.rank + 1
    x = torch.from_numpy(x)
x.requires_grad = True

print(f"rank {P_world.rank}; index {P_x.index}; value {x}")

# Apply the layer.
#
# Output tensor will be (on a 2 x 2 partition):
# [ [ 1 1 1 | 1 1 ]
#   [ 1 1 1 | 1 1 ]
#   [ 1 1 1 | 1 1 ]
#   [ 1 1 1 | 1 1 ]
#   -------------
#   [ 1 1 1 | 1 1 ]
#   [ 1 1 1 | 1 1 ]
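# A standalone illustration (not the library routine) of the block sizes in
# the diagrams above: a 7 x 5 global tensor over a 2 x 2 partition, assuming,
# as the drawing suggests, that earlier workers along each dimension absorb
# the remainder, giving 4 x 3, 4 x 2, 3 x 3, and 3 x 2 local blocks.
import numpy as np

def balanced_subshape(P_shape, index, global_shape):
    P_shape = np.atleast_1d(P_shape)
    index = np.atleast_1d(index)
    global_shape = np.atleast_1d(global_shape)
    subshape = global_shape // P_shape
    subshape[index < global_shape % P_shape] += 1
    return subshape

for index in [(0, 0), (0, 1), (1, 0), (1, 1)]:
    print(index, balanced_subshape([2, 2], index, [7, 5]))
# (0, 0) [4 3]   (0, 1) [4 2]   (1, 0) [3 3]   (1, 1) [3 2]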
def backward(ctx, grad_output):

    P_union = ctx.P_union
    x_global_shape = ctx.x_global_shape

    P_x = ctx.P_x
    in_data = ctx.in_data
    in_buffers = ctx.in_buffers

    P_y = ctx.P_y
    out_data = ctx.out_data
    out_buffers = ctx.out_buffers

    preserve_batch = ctx.preserve_batch

    dtype = ctx.dtype
    input_requires_grad = ctx.input_requires_grad

    requests = []

    # Default everyone to output None
    if preserve_batch:
        grad_input = zero_volume_tensor(grad_output.shape[0])
    else:
        grad_input = zero_volume_tensor()

    # Recv my input parts
    recv_count = 0
    if P_x.active:
        for (sl, sz, partner), buff in zip(in_data, in_buffers):
            if buff is not None:
                req = P_union.comm.Irecv(buff, source=partner, tag=113)
                requests.append(req)
            else:
                # We add this if there is no recv so that the indices of
                # the requests array match the indices of in_data and
                # in_buffers.
                requests.append(MPI.REQUEST_NULL)
            recv_count += 1

    # Pack and send my input parts
    send_count = 0
    if P_y.active:
        grad_output_numpy = grad_output.detach().numpy()
        for (sl, sz, partner), buff in zip(out_data, out_buffers):
            if buff is not None:
                np.copyto(buff, grad_output_numpy[tuple(sl)].ravel())
                req = P_union.comm.Isend(buff, dest=partner, tag=113)
                requests.append(req)
            else:
                # We add this for symmetry, but don't really need it.
                requests.append(MPI.REQUEST_NULL)
            send_count += 1

    if P_x.active:
        index = P_x.index
        x_local_shape = compute_subshape(P_x.shape, index, x_global_shape)
        # TODO(#25): The dtype should not be fixed, but correcting this is
        #            a thing that needs to be resolved globally.
        grad_input = np.zeros(x_local_shape, dtype=dtype)

    # Unpack the received data as it arrives
    completed_count = 0
    while (completed_count < len(requests)):
        status = MPI.Status()
        index = MPI.Request.Waitany(requests, status)

        # In MPI, we don't get the index out if the request is an
        # instance of MPI.REQUEST_NULL, instead MPI.UNDEFINED is returned.
        if P_x.active and index < recv_count and index != MPI.UNDEFINED:
            # Unpack my output parts
            sl, sz, partner = in_data[index]
            buff = in_buffers[index]
            if buff is not None:
                sh = grad_input[tuple(sl)].shape
                # This would normally be an add into the grad_input tensor
                # but we just created it, so a copy is sufficient.
                np.copyto(grad_input[tuple(sl)], buff.reshape(sh))

        completed_count += 1

    if P_x.active:
        grad_input = torch.from_numpy(grad_input)
        grad_input.requires_grad = input_requires_grad

    return grad_input, None, None, None, None, None, None, None, None, None, None
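# A minimal, self-contained sketch of the nonblocking communication pattern
# used in the backward pass above (Irecv/Isend posted up front, completed in
# arrival order with Waitany), stripped of all DistDL bookkeeping.  The
# buffer sizes, the tag, and the pairing of ranks 0 and 1 are illustrative
# only.  Run with e.g. `mpirun -np 2 python sketch.py`.
import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

# Ranks 0 and 1 exchange data; any other rank talks to MPI.PROC_NULL, for
# which the requests complete immediately without transferring anything.
partner = (1 - rank) if (rank in (0, 1) and comm.Get_size() > 1) else MPI.PROC_NULL

recv_buff = np.zeros(4, dtype=np.float64)
send_buff = np.full(4, float(rank), dtype=np.float64)

requests = [comm.Irecv(recv_buff, source=partner, tag=113),
            comm.Isend(send_buff, dest=partner, tag=113)]

# Complete the requests in whatever order they finish.
completed_count = 0
while completed_count < len(requests):
    MPI.Request.Waitany(requests)
    completed_count += 1

print(f"rank {rank} received {recv_buff}")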
def test_transpose_adjoint(barrier_fence_fixture,
                           comm_split_fixture,
                           P_x_ranks, P_x_shape,
                           P_y_ranks, P_y_shape,
                           x_global_shape,
                           balanced):

    import numpy as np
    import torch

    from distdl.backends.mpi.partition import MPIPartition
    from distdl.nn.transpose import DistributedTranspose
    from distdl.utilities.slicing import compute_subshape
    from distdl.utilities.torch import zero_volume_tensor

    # Use the GPU only when one is available.
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return
    P_world = MPIPartition(base_comm)

    # Create the partitions
    P_x_base = P_world.create_partition_inclusive(P_x_ranks)
    P_x = P_x_base.create_cartesian_topology_partition(P_x_shape)

    P_y_base = P_world.create_partition_inclusive(P_y_ranks)
    P_y = P_y_base.create_cartesian_topology_partition(P_y_shape)

    # The global tensor size is the same for x and y
    layer = DistributedTranspose(P_x, P_y, preserve_batch=False)
    layer = layer.to(device)

    # Forward Input
    x = zero_volume_tensor(device=device)
    if P_x.active:
        if balanced:
            x_local_shape = compute_subshape(P_x.shape,
                                             P_x.index,
                                             x_global_shape)
        else:
            quotient = np.atleast_1d(x_global_shape) // np.atleast_1d(P_x_shape)
            remainder = np.atleast_1d(x_global_shape) % np.atleast_1d(P_x_shape)
            loc = np.where(P_x.index == 0)
            x_local_shape = quotient.copy()
            x_local_shape[loc] += remainder[loc]

        x = torch.randn(*x_local_shape, device=device)
    x.requires_grad = True

    # Adjoint Input
    dy = zero_volume_tensor(device=device)
    if P_y.active:
        y_local_shape = compute_subshape(P_y.shape,
                                         P_y.index,
                                         x_global_shape)
        dy = torch.randn(*y_local_shape, device=device)

    # y = F @ x
    y = layer(x)

    # dx = F* @ dy
    y.backward(dy)
    dx = x.grad

    x = x.detach()
    dx = dx.detach()
    dy = dy.detach()
    y = y.detach()

    check_adjoint_test_tight(P_world, x, dx, y, dy)

    P_world.deactivate()
    P_x_base.deactivate()
    P_x.deactivate()
    P_y_base.deactivate()
    P_y.deactivate()
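# The test above checks the adjoint identity across all ranks via
# check_adjoint_test_tight.  The serial form of that identity is sketched
# below for an arbitrary linear map: with y = F x and dx = F* dy obtained
# from autograd, <y, dy> equals <x, dx>.  The Linear layer here is only a
# stand-in for intuition and has nothing to do with DistributedTranspose.
import torch

torch.manual_seed(0)
F = torch.nn.Linear(5, 3, bias=False)  # a bias would make the map affine, not linear

x = torch.randn(5, requires_grad=True)
dy = torch.randn(3)

y = F(x)
y.backward(dy)
dx = x.grad

lhs = torch.dot(y.detach(), dy)  # <F x, dy>
rhs = torch.dot(x.detach(), dx)  # <x, F* dy>
assert torch.allclose(lhs, rhs, rtol=1e-5)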
dilation = [1, 1, 1, 1]

exchange_info = mockup_conv_layer._compute_exchange_info(x_global_shape,
                                                         kernel_size,
                                                         stride,
                                                         padding,
                                                         dilation,
                                                         P_x.active,
                                                         P_x.shape,
                                                         P_x.index)
halo_shape = exchange_info[0]
recv_buffer_shape = exchange_info[1]
send_buffer_shape = exchange_info[2]

x_local_shape = compute_subshape(P_x.shape, P_x.index, x_global_shape)

value = (1 + rank) * (10 ** rank)
a = np.full(shape=x_local_shape, fill_value=value, dtype=float)

forward_input_padnd_layer = PadNd(halo_shape.astype(int), value=0, partition=P_x)
adjoint_input_padnd_layer = PadNd(halo_shape.astype(int), value=value, partition=P_x)

t = torch.tensor(a, requires_grad=True)
t_forward_input = forward_input_padnd_layer.forward(t)
t_adjoint_input = adjoint_input_padnd_layer.forward(t)

halo_layer = HaloExchange(P_x, halo_shape, recv_buffer_shape, send_buffer_shape)

print_sequential(cart_comm, f'rank = {rank}, t_forward_input =\n{t_forward_input.int()}')
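# A worked 1-D instance of the halo sizes that _compute_exchange_info derives
# above, using the standard PyTorch output-size and index-range formulas: a
# global length of 10 split 5 + 5 over two workers, with kernel_size=3,
# stride=1, padding=0, dilation=1.  All of the concrete numbers are
# illustrative assumptions for this sketch.
import numpy as np

L, k, s, p, d = 10, 3, 1, 0, 1
y_len = (L + 2 * p - d * (k - 1) - 1) // s + 1   # global output length: 8

x_starts = np.array([0, 5])                      # each worker owns 5 inputs
x_stops = x_starts + 5 - 1
y_starts = np.array([0, y_len // 2])             # each worker owns 4 outputs
y_stops = y_starts + y_len // 2 - 1

# Input index range each worker's outputs require, clamped to the domain.
need_lo = np.maximum(s * y_starts - p, 0)
need_hi = np.minimum(s * y_stops - p + d * (k - 1), L - 1)

left_halo = np.maximum(x_starts - need_lo, 0)
right_halo = np.maximum(need_hi - x_stops, 0)
print(left_halo, right_halo)   # [0 1] [1 0]: one ghost value at each interior edge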
def test_halo_exchange_adjoint(barrier_fence_fixture,
                               comm_split_fixture,
                               P_x_ranks, P_x_shape,
                               x_global_shape,
                               kernel_size,
                               stride,
                               padding,
                               dilation,
                               MockKernelStyle):
    import numpy as np
    import torch

    from distdl.backends.mpi.partition import MPIPartition
    from distdl.nn.halo_exchange import HaloExchange
    from distdl.nn.padnd import PadNd
    from distdl.utilities.slicing import compute_subshape
    from distdl.utilities.torch import zero_volume_tensor

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return
    P_world = MPIPartition(base_comm)

    P_x_base = P_world.create_partition_inclusive(P_x_ranks)
    P_x = P_x_base.create_cartesian_topology_partition(P_x_shape)

    x_global_shape = np.asarray(x_global_shape)
    kernel_size = np.asarray(kernel_size)
    stride = np.asarray(stride)
    padding = np.asarray(padding)
    dilation = np.asarray(dilation)

    halo_shape = None
    recv_buffer_shape = None
    send_buffer_shape = None
    if P_x.active:
        mockup_layer = MockKernelStyle()
        exchange_info = mockup_layer._compute_exchange_info(x_global_shape,
                                                            kernel_size,
                                                            stride,
                                                            padding,
                                                            dilation,
                                                            P_x.active,
                                                            P_x.shape,
                                                            P_x.index)
        halo_shape = exchange_info[0]
        recv_buffer_shape = exchange_info[1]
        send_buffer_shape = exchange_info[2]

        pad_layer = PadNd(halo_shape, value=0)
        halo_layer = HaloExchange(P_x, halo_shape, recv_buffer_shape,
                                  send_buffer_shape)

    x = zero_volume_tensor(x_global_shape[0])
    if P_x.active:
        x_local_shape = compute_subshape(P_x.shape,
                                         P_x.index,
                                         x_global_shape)
        x = torch.tensor(np.random.randn(*x_local_shape))
        x = pad_layer.forward(x)
    x.requires_grad = True

    dy = zero_volume_tensor(x_global_shape[0])
    if P_x.active:
        dy = torch.tensor(np.random.randn(*x.shape))

    x_clone = x.clone()
    dy_clone = dy.clone()

    # x_clone is modified in place by halo_layer, but we assign y to
    # reference it for clarity
    y = halo_layer(x_clone)

    # dy_clone is modified in place by halo_layer-adjoint, but we assign dx
    # to reference it for clarity
    y.backward(dy_clone)
    dx = dy_clone

    x = x.detach()
    dx = dx.detach()
    dy = dy.detach()
    y = y.detach()

    check_adjoint_test_tight(P_world, x, dx, y, dy)
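# The adjoint relationship verified above can be seen in a purely serial
# sketch: a forward halo exchange copies owned boundary values into a
# neighbor's ghost cells, so its adjoint adds the neighbor's ghost-cell
# gradients back onto the owned boundary and clears the ghost cells.  The
# two slabs, the halo width, and the copy pattern below are illustrative
# assumptions, not DistDL internals.
import numpy as np

halo = 1
# Padded local slabs: [ghost | owned | ghost]; ghosts start at zero.
x0 = np.array([0.0, 1.0, 2.0, 3.0, 0.0])
x1 = np.array([0.0, 4.0, 5.0, 6.0, 0.0])

def forward(a, b):
    a, b = a.copy(), b.copy()
    a[-halo:] = b[halo:2 * halo]       # fill worker 0's right ghost
    b[:halo] = a[-2 * halo:-halo]      # fill worker 1's left ghost
    return a, b

def adjoint(da, db):
    da, db = da.copy(), db.copy()
    db[halo:2 * halo] += da[-halo:]    # accumulate ghost gradient on the owner
    da[-halo:] = 0.0
    da[-2 * halo:-halo] += db[:halo]
    db[:halo] = 0.0
    return da, db

y0, y1 = forward(x0, x1)
dy0, dy1 = np.random.randn(5), np.random.randn(5)
dx0, dx1 = adjoint(dy0, dy1)

# <y, dy> == <x, dx>, the same identity check_adjoint_test_tight evaluates.
assert np.isclose(y0 @ dy0 + y1 @ dy1, x0 @ dx0 + x1 @ dx1)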