def _distdl_module_setup(self, input): r"""AllSumReduce module setup function. Constructs the necessary partition functions to implement the above described reduction pattern. This function performs collective communication across the input and output partitions. This function is called every time something changes in the input tensor structure. It should not be called manually. Parameters ---------- input : Tuple of forward inputs. See `torch.nn.Module.register_forward_pre_hook` for more details. """ if not (self.P_x.active): return # If it is not an identity, we need actual Partitions to do the work. if not self.identity: self.P_allreduce = self.P_x.create_allreduction_partition( self.axes_reduce) self.input_tensor_structure = TensorStructure(input[0]) self.output_tensor_structure = self.input_tensor_structure self._distdl_is_setup = True self._input_tensor_structure = TensorStructure(input[0])
def _distdl_module_teardown(self, input): r"""Broadcast module teardown function. Nullifies the necessary partition functions. This function is called every time something changes in the input tensor structure. It should not be called manually. Parameters ---------- input : Tuple of forward inputs. See `torch.nn.Module.register_forward_pre_hook` for more details. """ # Reset all of the buffers and communication objects self.P_send.deactivate() self.P_recv.deactivate() # Reset any data stored about the tensor self.input_tensor_structure = TensorStructure() self.output_tensor_structure = TensorStructure() # Reset any info about the input self._distdl_is_setup = False self._input_tensor_structure = TensorStructure()
def _distdl_module_teardown(self, input): r"""Transpose module teardown function. Deallocates buffers safely. This function is called every time something changes in the input tensor structure. It should not be called manually. Parameters ---------- input : Tuple of forward inputs. See `torch.nn.Module.register_forward_pre_hook` for more details. """ # Reset all of the buffers and communication objects self.P_x_to_y_overlaps = [] self.P_y_to_x_overlaps = [] self.P_x_to_y_buffers = None self.P_y_to_x_buffers = None # Reset any info about the input self._distdl_is_setup = False self._input_tensor_structure = TensorStructure()
def __init__(self, P_x, buffer_manager=None, size=None, scale_factor=None, mode='linear', align_corners=False): super(DistributedUpsample, self).__init__() if mode == 'cubic': raise NotImplementedError( 'Cubic interpolation is not implemented.') if size is None and scale_factor is None: raise ValueError("One of `size` or `scale_factor` must be set.") if size is not None and scale_factor is not None: raise ValueError( "Only one of `size` or `scale_factor` may be set.") # P_x is 1 x 1 x P_d-1 x ... x P_0 self.P_x = P_x # Back-end specific buffer manager for economic buffer allocation if buffer_manager is None: buffer_manager = self._distdl_backend.BufferManager() elif type(buffer_manager) is not self._distdl_backend.BufferManager: raise ValueError("Buffer manager type does not match backend.") self.buffer_manager = buffer_manager if not self.P_x.active: return # Do this before checking serial so that the layer works properly # in the serial case # self.pool_layer = self.TorchPoolType(*args, **kwargs) self.mode = mode self.align_corners = align_corners self.size = size self.scale_factor = scale_factor # Local input and output tensor structures, defined when layer is called self.input_tensor_structure = TensorStructure() self.output_tensor_structure = TensorStructure() # We need the actual sizes to determine the interpolation layer self.interp_layer = None # For the halo layer we also defer construction, so that we can have # the halo shape for the input. The halo will allocate its own # buffers, but it needs this information at construction to be able # to do this in the pre-forward hook. self.halo_layer = None # Variables for tracking input changes and buffer construction self._distdl_is_setup = False self._input_tensor_structure = TensorStructure()
def __init__(self, P_x, P_y, transpose_src=False, transpose_dest=False, preserve_batch=True): super(Broadcast, self).__init__() # Partition of input tensor. self.P_x = P_x # Partition of output tensor. self.P_y = P_y # Transpose the input partition prior to the broadcast. self.transpose_src = transpose_src # Transpose the output partition prior to the broadcast. self.transpose_dest = transpose_dest # Indicates if batch size should be preserved for zero-volume outputs. self.preserve_batch = preserve_batch # Indicates if broadcast requires any data movement. self.identity = False # Partition for sharing copy of local data. self.P_send = self._distdl_backend.Partition() # Partition for receiving copy of local data. self.P_recv = self._distdl_backend.Partition() # Other info needed by the functions # Structure of the input tensor (shape, dtype, requires_grad, etc). self.input_tensor_structure = TensorStructure() # Structure of the output tensor (shape, dtype, requires_grad, etc). self.output_tensor_structure = TensorStructure() # Variables for tracking input changes and buffer construction self._distdl_is_setup = False self._input_tensor_structure = TensorStructure() # The identity case is if the partitions are of size 1, # or they are the same partition and neither is tranposed, # or they are the same partition and both are transposed. if self.P_x == self.P_y: if self.P_x.size == 1: self.identity = True elif (self.transpose_dest and self.transpose_src) or \ (not self.transpose_dest and not self.transpose_src): self.identity = True
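# Minimal sketch (plain Python, not the DistDL API) of the identity rule the
# constructor above encodes: a broadcast requires no data movement when the
# two partitions are the same and are of size one, or are the same partition
# with matching transpose flags.
def broadcast_is_identity(same_partition, partition_size,
                          transpose_src, transpose_dest):
    """Return True when the broadcast requires no data movement."""
    if not same_partition:
        return False
    if partition_size == 1:
        return True
    # Same partition: identity only if both or neither side is transposed.
    return transpose_src == transpose_dest


# Examples
assert broadcast_is_identity(True, 1, False, True)      # size-1 partition
assert broadcast_is_identity(True, 4, True, True)       # both transposed
assert not broadcast_is_identity(True, 4, True, False)  # mixed flags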
def __init__(self, P_x, halo_shape, recv_buffer_shape, send_buffer_shape, buffer_manager=None): super(HaloExchange, self).__init__() self.P_x = P_x self.halo_shape = halo_shape self.recv_buffer_shape = recv_buffer_shape self.send_buffer_shape = send_buffer_shape self.neighbor_ranks = None if self.P_x.active: self.neighbor_ranks = self.P_x.neighbor_ranks(self.P_x.rank) self.slices = None self.buffers = None # Back-end specific buffer manager for economic buffer allocation if buffer_manager is None: buffer_manager = self._distdl_backend.BufferManager() elif type(buffer_manager) is not self._distdl_backend.BufferManager: raise ValueError("Buffer manager type does not match backend.") self.buffer_manager = buffer_manager # Variables for tracking input changes and buffer construction self._distdl_is_setup = False self._input_tensor_structure = TensorStructure() # Get some types and functions from the back-end self.allocate_halo_exchange_buffers = self._distdl_backend.halo_exchange.allocate_halo_exchange_buffers
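# Illustrative sketch (pure NumPy, not the DistDL back-end) of what
# `neighbor_ranks` above conceptually returns: for each dimension of a
# Cartesian partition, the left/right neighbor is the worker whose index
# differs by -1/+1, or -1 at the boundary.  The actual DistDL ordering and
# return structure may differ.
import numpy as np

def neighbor_ranks(shape, rank):
    """Return [(left, right), ...] neighbor ranks, one pair per dimension."""
    shape = np.asarray(shape)
    index = np.array(np.unravel_index(rank, shape))
    neighbors = []
    for d in range(len(shape)):
        left = index.copy()
        left[d] -= 1
        right = index.copy()
        right[d] += 1
        l_rank = int(np.ravel_multi_index(left, shape)) if left[d] >= 0 else -1
        r_rank = int(np.ravel_multi_index(right, shape)) if right[d] < shape[d] else -1
        neighbors.append((l_rank, r_rank))
    return neighbors

# Worker 4 in a 3 x 3 grid (the center) has neighbors in both dimensions.
print(neighbor_ranks([3, 3], 4))  # [(1, 7), (3, 5)]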
def _distdl_module_setup(self, input): r"""Distributed loss setup function. This function is called every time something changes in the input tensor structure. It should not be called manually. Parameters ---------- input : Tuple of forward inputs. See `torch.nn.Module.register_forward_pre_hook` for more details. """ if not self.P_x.active: return # If the reduction mode is "mean", we need the total size of the # global input tensor. if self.reduction == "mean": N_local = np.prod(input[0].shape) self.normalization_factor = int( self.P_x.allreduce_data(np.asarray(N_local))) self._distdl_is_setup = True self._input_tensor_structure = TensorStructure(input[0])
def _distdl_module_setup(self, input): r"""Distributed KL Divergence loss setup function. This function is called every time something changes in the input tensor structure. It should not be called manually. Parameters ---------- input : Tuple of forward inputs. See `torch.nn.Module.register_forward_pre_hook` for more details. """ if not self.P_x.active: return # If the reduction mode is "mean", we need the total size of the # global input tensor. if self.reduction == "mean": N_local = np.prod(input[0].shape) self.normalization_factor = int( self.P_x.allreduce_data(np.asarray(N_local))) # If the reduction mode is "batchmean", we need the batch size, which # means we only add the shapes along the batch axis. elif self.reduction == "batchmean": if all(coord == 0 for coord in self.P_x.index[1:]): N_local = input[0].shape[0] else: N_local = 0 self.normalization_factor = int( self.P_x.allreduce_data(np.asarray(N_local))) self._distdl_is_setup = True self._input_tensor_structure = TensorStructure(input[0])
def _distdl_module_teardown(self, input):

    # Reset all of the buffers and communication objects
    self.slices = None
    self.buffers = None

    # Reset any info about the input
    self._distdl_is_setup = False
    self._input_tensor_structure = TensorStructure()
def __init__(self, P_x, axes_reduce=None, axes_keep=None): super(AllSumReduce, self).__init__() # Partition of input and output tensor. self.P_x = P_x # Partition dimensions along which the all-reduction takes place. # While we compute both terms, `axes_reduce` is used internally. if axes_reduce is None and axes_keep is None: raise ValueError( "One of `axes_reduce` or `axes_keep` must be specified.") elif axes_reduce is not None and axes_keep is not None: raise ValueError( "Only one of `axes_reduce` or `axes_keep` may be specified.") elif axes_reduce is not None: self.axes_reduce = axes_reduce self.axes_keep = [ d for d in range(P_x.dim) if d not in axes_reduce ] elif axes_keep is not None: self.axes_reduce = [ d for d in range(P_x.dim) if d not in axes_keep ] self.axes_keep = axes_keep # Indicates if broadcast requires any data movement. self.identity = False # Partition for performing all-reduction. self.P_allreduce = self._distdl_backend.Partition() # Structure of the input tensor (shape, dtype, requires_grad, etc). self.input_tensor_structure = TensorStructure() # Structure of the output tensor (shape, dtype, requires_grad, etc). self.output_tensor_structure = TensorStructure() # Variables for tracking input changes and buffer construction self._distdl_is_setup = False self._input_tensor_structure = TensorStructure() # The identity case is if the partition is of size 1, if self.P_x.size == 1: self.identity = True
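# Sketch of the axes_reduce / axes_keep bookkeeping above: exactly one of the
# two is given, and the other is its complement over the partition dimensions.
def resolve_axes(dim, axes_reduce=None, axes_keep=None):
    if (axes_reduce is None) == (axes_keep is None):
        raise ValueError("Exactly one of `axes_reduce` or `axes_keep` "
                         "must be specified.")
    if axes_reduce is not None:
        axes_keep = [d for d in range(dim) if d not in axes_reduce]
    else:
        axes_reduce = [d for d in range(dim) if d not in axes_keep]
    return list(axes_reduce), list(axes_keep)

# On a 4-dimensional partition, reducing over dims 2 and 3 keeps 0 and 1.
print(resolve_axes(4, axes_reduce=[2, 3]))  # ([2, 3], [0, 1])
print(resolve_axes(4, axes_keep=[0]))       # ([1, 2, 3], [0])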
def assemble_global_tensor_structure(local_tensor_structure, P_in, P_out=None): global_tensor_structure = TensorStructure() global_tensor_shape = None intID_dtype = None requires_grad_int = None if P_in.active: # Assemble the global shape global_tensor_shape = np.zeros(P_in.dim, dtype=np.int) for i in range(P_in.dim): keep = [False] * P_in.dim keep[i] = True P_sub = P_in.create_cartesian_subtopology_partition(keep) v0 = np.atleast_1d(int(local_tensor_structure.shape[i])) v1 = np.zeros(1, dtype=np.int) P_sub._comm.Allreduce(v0, v1, op=MPI.SUM) global_tensor_shape[i] = v1[0] # Free the subtopology resources P_sub.deactivate() # Get a communicable integer representing the dtype intID_dtype = torch_to_intID_dtype_dict[local_tensor_structure.dtype] intID_dtype = np.array([intID_dtype], dtype=np.int) requires_grad_int = np.array([-1], dtype=np.int) requires_grad_int[0] = 1 if local_tensor_structure.requires_grad else 0 global_tensor_structure.shape = global_tensor_shape global_tensor_structure.dtype = local_tensor_structure.dtype global_tensor_structure.requires_grad = local_tensor_structure.requires_grad if P_out is not None and P_out.active: # Share the shape global_tensor_structure.shape = P_out.broadcast_data( global_tensor_shape, P_data=P_in) # Share the dtype intID_dtype = P_out.broadcast_data(intID_dtype, P_data=P_in) global_tensor_structure.dtype = intID_to_torch_dtype_dict[ intID_dtype[0]] # Share the requires_grad status requires_grad_int = P_out.broadcast_data(requires_grad_int, P_data=P_in) global_tensor_structure.requires_grad = bool(requires_grad_int[0]) return global_tensor_structure
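# Pure-NumPy illustration (no MPI) of the shape assembly above: along each
# partition dimension, the global extent is the sum of the local extents of
# the workers in the 1-D subtopology that varies only in that dimension.
# The partition and local shapes are hypothetical.
import numpy as np

# Local tensor shapes on a 2 x 3 partition of a 2-D tensor.
local_shapes = np.array([[[5, 4], [5, 3], [5, 3]],
                         [[4, 4], [4, 3], [4, 3]]])  # shape (P_0, P_1, dim)

def global_shape(local_shapes, index):
    """Assemble the global shape as seen from the worker at `index`."""
    dim = local_shapes.shape[-1]
    out = np.zeros(dim, dtype=int)
    for i in range(dim):
        # Vary only partition dimension i; hold all others at our own index.
        sl = list(index)
        sl[i] = slice(None)
        out[i] = local_shapes[tuple(sl) + (i,)].sum()
    return out

print(global_shape(local_shapes, (0, 0)))  # [ 9 10]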
def _distdl_module_setup(self, input):

    if self.P_x.active:
        x_local_shape = input[0].shape
        self.slices = self._assemble_slices(x_local_shape,
                                            self.recv_buffer_shape,
                                            self.send_buffer_shape)
        self.buffers = self.allocate_halo_exchange_buffers(self.buffer_manager,
                                                           self.slices,
                                                           self.recv_buffer_shape,
                                                           self.send_buffer_shape,
                                                           input[0].dtype)

    self._distdl_is_setup = True
    self._input_tensor_structure = TensorStructure(input[0])
def _distdl_input_changed(self, input):
    r"""Determine if the structure of inputs has changed.

    Parameters
    ----------
    input :
        Tuple of forward inputs.  See
        `torch.nn.Module.register_forward_pre_hook` for more details.

    """

    new_tensor_structure = TensorStructure(input[0])

    return self._input_tensor_structure != new_tensor_structure
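# Sketch of the structure-comparison idea behind this hook, using a stand-in
# structure class (the real TensorStructure lives in DistDL and may carry
# more fields): only shape, dtype, and requires_grad are compared, so a new
# tensor with the same structure does not trigger teardown/re-setup.
from dataclasses import dataclass
import torch

@dataclass
class SimpleTensorStructure:
    shape: torch.Size = None
    dtype: torch.dtype = None
    requires_grad: bool = None

    @classmethod
    def from_tensor(cls, t):
        return cls(t.shape, t.dtype, t.requires_grad)

x0 = torch.zeros(2, 3)
x1 = torch.zeros(2, 3)   # same structure, different values
x2 = torch.zeros(2, 4)   # different shape

s = SimpleTensorStructure.from_tensor(x0)
print(s != SimpleTensorStructure.from_tensor(x1))  # False: no re-setup needed
print(s != SimpleTensorStructure.from_tensor(x2))  # True: triggers teardown/setup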
def _distdl_module_setup(self, input): r"""Broadcast module setup function. Constructs the necessary partition functions to implement the above described broadcast pattern. This function performs collective communication across the input and output partitions. This function is called every time something changes in the input tensor structure. It should not be called manually. Parameters ---------- input : Tuple of forward inputs. See `torch.nn.Module.register_forward_pre_hook` for more details. """ if not (self.P_x.active or self.P_y.active): return # If it is not an identity, we need actual Partitions to do the work. if not self.identity: bcast_partitions = self.P_x.create_broadcast_partition_to( self.P_y, self.transpose_src, self.transpose_dest) self.P_send = bcast_partitions[0] self.P_recv = bcast_partitions[1] self.input_tensor_structure = TensorStructure(input[0]) self.output_tensor_structure = \ self._distdl_backend.broadcast_tensor_structure(self.input_tensor_structure, self.P_send, self.P_recv) self._distdl_is_setup = True self._input_tensor_structure = TensorStructure(input[0])
def _distdl_module_teardown(self, input): r"""Distributed (channel) convolution module teardown function. This function is called every time something changes in the input tensor structure. It should not be called manually. Parameters ---------- input : Tuple of forward inputs. See `torch.nn.Module.register_forward_pre_hook` for more details. """ # Reset any info about the input self._distdl_is_setup = False self._input_tensor_structure = TensorStructure()
def _distdl_module_teardown(self, input): r"""Distributed loss teardown function. Nullifies the necessary partition functions. This function is called every time something changes in the input tensor structure. It should not be called manually. Parameters ---------- input : Tuple of forward inputs. See `torch.nn.Module.register_forward_pre_hook` for more details. """ self.normalization_factor = 1 # Reset any info about the input self._distdl_is_setup = False self._input_tensor_structure = TensorStructure()
def _distdl_module_setup(self, input): r"""Distributed (channel) convolution module setup function. This function is called every time something changes in the input tensor structure. It should not be called manually. Parameters ---------- input : Tuple of forward inputs. See `torch.nn.Module.register_forward_pre_hook` for more details. """ # No setup is needed if the worker is not doing anything for this # layer. if not (self.P_x.active or self.P_y.active or self.P_w.active): return if self.serial: return self._distdl_is_setup = True self._input_tensor_structure = TensorStructure(input[0])
def __init__(self, P_x, P_y, preserve_batch=True, buffer_manager=None): super(DistributedTranspose, self).__init__() # Global structure of the input tensor, assembled when layer is called self.global_input_tensor_structure = TensorStructure() # Local input and output tensor structures, defined when layer is called self.input_tensor_structure = TensorStructure() self.output_tensor_structure = TensorStructure() # Partition of input tensor. self.P_x = P_x # Partition of output tensor. self.P_y = P_y # Indicates if batch size should be preserved for zero-volume outputs. self.preserve_batch = preserve_batch # List of meta data describing copies of subvolumes of input tensor # out of the current worker self.P_x_to_y_overlaps = [] # List of meta data describing copies of subvolumes of output tensor # into the current worker self.P_y_to_x_overlaps = [] # Back-end specific buffer manager for economic buffer allocation if buffer_manager is None: buffer_manager = self._distdl_backend.BufferManager() elif type(buffer_manager) is not self._distdl_backend.BufferManager: raise ValueError("Buffer manager type does not match backend.") self.buffer_manager = buffer_manager # List of buffers for copying data to other workers self.P_x_to_y_buffers = None # List of buffers for copying data from other workers self.P_y_to_x_buffers = None # Variables for tracking input changes and buffer construction self._distdl_is_setup = False self._input_tensor_structure = TensorStructure() # Otherwise, we need the union of the input and output partitions # so that data can be copied across them. P_union = self._distdl_backend.Partition() if P_x.active or P_y.active: P_union = P_x.create_partition_union(P_y) self.P_union = P_union # Setup these variables incase the current worker is inactive in # the union. self.P_x_shape = None self.P_y_shape = None self.union_indices = None if not P_union.active: return # All active workers need the shapes of both partitions so that buffer # sizes and subtensor overlaps can be computed. data = None if self.P_x.active: data = self.P_x.shape self.P_x_shape = self.P_union.broadcast_data(data, P_data=self.P_x) data = None if self.P_y.active: data = self.P_y.shape self.P_y_shape = self.P_union.broadcast_data(data, P_data=self.P_y) if len(self.P_x_shape) != len(self.P_y_shape): raise ValueError( "Input and output partition must be same dimension.") # Share the two indices with every worker in the union. The first # column of data contains the source index and the second contains # the destination index. data = np.array([P_x.rank if P_x.active else -1], dtype=np.int) self.P_x_ranks = P_union.allgather_data(data) data = np.array([P_y.rank if P_y.active else -1], dtype=np.int) self.P_y_ranks = P_union.allgather_data(data) # Get some types and functions from the back-end self.allocate_transpose_buffers = self._distdl_backend.transpose.allocate_transpose_buffers
def broadcast_tensor_structure(input_tensor_structure, P_send, P_recv): output_tensor_structure = TensorStructure() if not P_send.active and not P_recv.active: return output_tensor_structure requests = [] if P_send.active: # Share the torch dtype code, converted to an int. intID_dtype = torch_to_intID_dtype_dict[input_tensor_structure.dtype] send_intID_dtype = np.array([intID_dtype], dtype=np.int) req = P_send._comm.Iallreduce(MPI.IN_PLACE, send_intID_dtype, op=MPI.MAX) requests.append(req) # Need to send non-Python types, so convert the boolean temporarily rg_int_send = np.array([-1], dtype=np.int) rg_int_send[0] = 1 if input_tensor_structure.requires_grad else 0 req = P_send._comm.Iallreduce(MPI.IN_PLACE, rg_int_send, op=MPI.MAX) requests.append(req) # Sending processes know the shape, so they can send a copy of the # data. We will ignore this variable later. send_tensor_dim = np.array([len(input_tensor_structure.shape)], dtype=np.int) req = P_send._comm.Iallreduce(MPI.IN_PLACE, send_tensor_dim, op=MPI.MAX) requests.append(req) # Similarly, sending processes know the tensor shape, so they can send # a copy of it, but we will not use that copy for our actual return # value. send_tensor_shape = np.array(input_tensor_structure.shape, dtype=np.int) req = P_send._comm.Iallreduce(MPI.IN_PLACE, send_tensor_shape, op=MPI.MAX) requests.append(req) # If the process is a receiving process, but doesn't already know the data # because it is the _same_ sending process, then we receive the results. # If it is a receiving process that sent data to a different set of # processes, we still have to complete the receive, even though later we # will not use that data. if (P_send != P_recv) and P_recv.active: # Everyone needs to receive these two values, but we don't need them # for future communication in this function so we can defer receiving # the data. recv_intID_dtype = np.array([-1], dtype=np.int) req = P_recv._comm.Iallreduce(MPI.IN_PLACE, recv_intID_dtype, op=MPI.MAX) requests.append(req) rg_int_recv = np.array([-1], dtype=np.int) req = P_recv._comm.Iallreduce(MPI.IN_PLACE, rg_int_recv, op=MPI.MAX) requests.append(req) # We need this value for the next communication, so we have to wait # for it to complete before moving on. recv_tensor_dim = np.array([-1], dtype=np.int) req = P_recv._comm.Iallreduce(MPI.IN_PLACE, recv_tensor_dim, op=MPI.MAX) req.Wait() recv_tensor_shape = np.zeros(recv_tensor_dim, dtype=np.int) recv_tensor_shape[:] = -1 req = P_recv._comm.Iallreduce(MPI.IN_PLACE, recv_tensor_shape, op=MPI.MAX) requests.append(req) # Make sure all requests, including the final recv all reduce complete # before receiving processes can actually copy the data out. MPI.Request.Waitall(requests) # Wait until the communication is complete to set these values. Only # receiving ranks that do not have the data originally should enter here. if P_recv.active and (P_send != P_recv): output_tensor_structure.shape = torch.Size(recv_tensor_shape) output_tensor_structure.dtype = intID_to_torch_dtype_dict[ recv_intID_dtype[0]] output_tensor_structure.requires_grad = bool(rg_int_recv[0]) elif P_send == P_recv: output_tensor_structure.shape = input_tensor_structure.shape output_tensor_structure.dtype = input_tensor_structure.dtype output_tensor_structure.requires_grad = input_tensor_structure.requires_grad # Finally, every active worker should have valid data. Any sending rank # created it from input data. Any receving _only_ rank used what it was # given. return output_tensor_structure
def __init__(self, P_x, P_y, P_w, in_channels, out_channels, kernel_size, stride=1, padding=0, padding_mode='zeros', dilation=1, groups=1, bias=True, buffer_manager=None): super(DistributedGeneralConvBase, self).__init__() # P_x is 1 x P_ci x P_d-1 x ... x P_0 self.P_x = P_x # P_y is 1 x P_co x P_d-1 x ... x P_0 self.P_y = P_y # P_w is P_co x P_ci x P_d-1 x ... x P_0 self.P_w = P_w # Back-end specific buffer manager for economic buffer allocation if buffer_manager is None: buffer_manager = self._distdl_backend.BufferManager() elif type(buffer_manager) is not self._distdl_backend.BufferManager: raise ValueError("Buffer manager type does not match backend.") self.buffer_manager = buffer_manager # Even inactive workers need some partition union self.P_union = self._distdl_backend.Partition() if not (self.P_x.active or self.P_y.active or self.P_w.active): return self.in_channels = in_channels self.out_channels = out_channels self.kernel_size = self._expand_parameter(kernel_size) self.stride = self._expand_parameter(stride) self.padding = self._expand_parameter(padding) self.padding_mode = padding_mode self.dilation = self._expand_parameter(dilation) self.groups = groups self.use_bias = bias # This guarantees that P_union rank 0 has the kernel size, stride, # padding, and dilation factors P_union_temp = P_w.create_partition_union(P_x) self.P_union = P_union_temp.create_partition_union(P_y) # Release the temporary resources P_union_temp.deactivate() # Ensure that all workers have the full size and structure of P_w P_w_shape = None if self.P_union.rank == 0: P_w_shape = np.array(P_w.shape, dtype=np.int) P_w_shape = self.P_union.broadcast_data(P_w_shape, root=0) P_co = P_w_shape[0] P_ci = P_w_shape[1] P_channels = [P_co, P_ci] # Ensure that P_x and P_w are correctly aligned. We also produce a # new P_x that is shaped like 1 x P_ci x P_d-1 x ... x P_0, to assist # with broadcasts. P_x_new_shape = [] if self.P_x.active: if(np.any(P_x.shape[2:] != P_w_shape[2:])): raise ValueError("Spatial components of P_x and P_w must match.") if P_w_shape[1] != P_x.shape[1]: raise ValueError("Index 2 of P_w dimension must match input channel partition.") P_x_new_shape = list(P_x.shape) P_x_new_shape.insert(1, 1) # Currently a hack, removing the batch dimension because P_w does # not have one. This is OK because we assume there are no partitions # in the batch dimension. P_x_new_shape = np.asarray(P_x_new_shape[1:], dtype=int) # For the purposes of this layer, we re-cast P_x to have the extra # dimension. This has no impact outside of the layer or on the results. self.P_x = self.P_x.create_cartesian_topology_partition(P_x_new_shape) # Ensure that P_y and P_w are correctly aligned. We also produce a # new P_y that is shaped like 1 x P_ci x P_d-1 x ... x P_0, to assist # with broadcasts. P_y_new_shape = [] if self.P_y.active: if(np.any(P_y.shape[2:] != P_w_shape[2:])): raise ValueError("Spatial components of P_y and P_w must match.") if P_w_shape[0] != P_y.shape[1]: raise ValueError("Index 1 of P_w dimension must match output channel partition.") P_y_new_shape = list(P_y.shape) P_y_new_shape.insert(2, 1) # Currently a hack, removing the batch dimension because P_w does # not have one. This is OK because we assume there are no partitions # in the batch dimension. P_y_new_shape = np.asarray(P_y_new_shape[1:], dtype=int) # For the purposes of this layer, we re-cast P_x to have the extra # dimension. This has no impact outside of the layer or on the results. 
self.P_y = self.P_y.create_cartesian_topology_partition(P_y_new_shape) P_spatial = P_w_shape[2:] self.serial = self.P_w.size == 1 if self.serial: self.conv_layer = self.TorchConvType(in_channels=in_channels, out_channels=out_channels, kernel_size=self.kernel_size, stride=self.stride, padding=self.padding, padding_mode=self.padding_mode, dilation=self.dilation, groups=self.groups, bias=self.use_bias) self.weight = self.conv_layer.weight self.bias = self.conv_layer.bias return # Need to figure out any padding necessary to handle global padding. # This is only on the input tensor. The convolution will not use # any implicit padding, so the work partition does not need it. if self.P_x.active: dims = len(self.P_x.shape) # We will be using global padding to compute local padding, # so expand it to a numpy array global_padding = np.pad(self.padding, pad_width=(dims-len(self.padding), 0), mode='constant', constant_values=0) self.global_padding = global_padding pad_left_right = self.global_padding.reshape((dims, 1)) + np.zeros((dims, 2), dtype=np.int) self.local_padding = self._compute_local_padding(pad_left_right) # Workers can either store the learnable weights and bias, or they # need copies of it. self.receives_weight = False self.stores_weight = False self.receives_bias = False self.stores_bias = False # Determine root partitions, initialize weights there if self.P_w.active: # All of P_w always receives the weight self.receives_weight = True # This subset is taken to be the origin of the spartial component w_root_subset = [] for i, c in enumerate(range_index(P_w.shape)): c = np.asarray(c) # Find the P_co x P_ci x 1 x ... x 1 subset to store the weights if np.all(c[2:] == 0): w_root_subset.append(i) P_wr_base = self.P_w.create_partition_inclusive(w_root_subset) # ones are needed so the broadcast will work self.P_wr = P_wr_base.create_cartesian_topology_partition([P_co, P_ci] + [1]*len(P_spatial)) self.stores_weight = self.P_wr.active # Release temporary resources P_wr_base.deactivate() b_subset = [] for i, c in enumerate(range_index(P_w.shape)): c = np.asarray(c) # Find the P_co x 1 x P_0 x ... x P_D-1 subset that needs # biases in its calculation. This is everywhere that the input # channels is rank 0. if c[1] == 0: b_subset.append(i) P_b_base = self.P_w.create_partition_inclusive(b_subset) self.P_b = P_b_base.create_cartesian_topology_partition([P_co] + [1] + list(P_spatial)) self.receives_bias = self.P_b.active and self.use_bias # Release temporary resources P_b_base.deactivate() # Now find the subset of _that_ which actually stores the # learnable parameter. b_root_subset = [] for i, c in enumerate(range_index(P_w.shape)): c = np.asarray(c) # Find the P_co x 1 x 1 x ... x 1 subset to store the biases if np.all(c[1:] == 0): b_root_subset.append(i) P_br_base = self.P_w.create_partition_inclusive(b_root_subset) # ones are needed so the broadcast will work self.P_br = P_br_base.create_cartesian_topology_partition([P_co] + [1] + [1]*len(P_spatial)) self.stores_bias = self.P_br.active and self.use_bias # Release temporary resources P_br_base.deactivate() # Correct the input arguments based on local properties # This ensures that the in and out channels are correctly shared. 
local_co, local_ci = compute_subshape(P_channels, P_w.index[0:2], [out_channels, in_channels]) self.conv_layer = self.TorchConvType(in_channels=local_ci, out_channels=local_co, kernel_size=self.kernel_size, stride=self.stride, padding=0, padding_mode='zeros', dilation=self.dilation, groups=groups, bias=self.receives_bias) # If we store the weight it is a learnable parameter iff it is # learnable by default in the layer, which it is. if self.stores_weight: self.weight = torch.nn.Parameter(self.conv_layer.weight.detach()) else: self.register_buffer('weight', zero_volume_tensor()) # This always exists so we can copy the property self.weight.requires_grad = self.conv_layer.weight.requires_grad # https://discuss.pytorch.org/t/assign-parameters-to-nn-module-and-have-grad-fn-track-it/62677/2 new_weight = self.conv_layer.weight.detach() * 0 new_weight.requires_grad = self.conv_layer.weight.requires_grad del self.conv_layer.weight self.conv_layer.weight = new_weight # If we store the bias, it is a learnable parameter iff it is # learnable by default in the layer, which is only true if it # exists. if self.stores_bias: self.bias = torch.nn.Parameter(self.conv_layer.bias.detach()) else: if self.use_bias: self.register_buffer('bias', zero_volume_tensor()) else: self.register_buffer('bias', None) # This does not always exist, but when it does we can copy the # property. if self.receives_bias: self.bias.requires_grad = self.conv_layer.bias.requires_grad # https://discuss.pytorch.org/t/assign-parameters-to-nn-module-and-have-grad-fn-track-it/62677/2 new_bias = self.conv_layer.bias.detach() * 0 new_bias.requires_grad = self.conv_layer.bias.requires_grad del self.conv_layer.bias self.conv_layer.bias = new_bias else: # Workers not in P_w don't have a weight or bias. self.register_buffer('weight', zero_volume_tensor()) if self.use_bias: self.register_buffer('bias', zero_volume_tensor()) else: self.register_buffer('bias', None) # Now we need to share the kernel structure. The size of the kernel # is always the spatial dimensions. self.conv_kernel_size = None self.conv_stride = None self.conv_padding = None self.conv_dilation = None # By construction, rank 0 of the union should always have all of this # information, because it will always construct a local conv layer. We # rely on the local conv layer to properly fill out this information # from the defaults. This info is required for all workers on the # input and output partitions because it is needed to construct the # halos. Rank 0 in the union shares it with everyone. if self.P_union.rank == 0: self.conv_kernel_size = np.array(self.conv_layer.kernel_size, dtype=np.int) self.conv_stride = np.array(self.conv_layer.stride, dtype=np.int) self.conv_padding = np.array(self.conv_layer.padding, dtype=np.int) self.conv_dilation = np.array(self.conv_layer.dilation, dtype=np.int) self.conv_kernel_size = self.P_union.broadcast_data(self.conv_kernel_size, root=0) self.conv_stride = self.P_union.broadcast_data(self.conv_stride, root=0) self.conv_padding = self.P_union.broadcast_data(self.conv_padding, root=0) self.conv_dilation = self.P_union.broadcast_data(self.conv_dilation, root=0) # We need to be able to remove some data from the input to the conv # layer but again need to defer. self.needed_slices = None # For the halo layer we also defer construction, so that we can have # the halo shape for the input. The halo will allocate its own # buffers, but it needs this information at construction to be able # to do this in the pre-forward hook. 
self.halo_layer = None # Variables for tracking input changes and buffer construction self._distdl_is_setup = False self._input_tensor_structure = TensorStructure() # Some layers, those that require no information about the input # tensor to setup, can be built now. if P_w.active: self.w_broadcast = Broadcast(self.P_wr, self.P_w, preserve_batch=False) if self.receives_bias or self.stores_bias: self.b_broadcast = Broadcast(self.P_br, self.P_b, preserve_batch=False) self.x_broadcast = Broadcast(self.P_x, self.P_w, preserve_batch=True) self.y_sum_reduce = SumReduce(self.P_w, self.P_y, preserve_batch=True)
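# Sketch of the balanced channel split implied by the `compute_subshape`
# call above.  This illustrates block decomposition in general, not the
# DistDL function itself: n items over p workers gives each worker n // p
# items, and (in this sketch) the first n % p workers get one extra; the
# exact placement of the remainder in DistDL may differ.
def block_size(n, p, index):
    q, r = divmod(n, p)
    return q + (1 if index < r else 0)

# 10 output channels over P_co = 4 and 7 input channels over P_ci = 3:
out_channels, in_channels = 10, 7
P_co, P_ci = 4, 3
print(block_size(out_channels, P_co, 0),   # worker with index[0] == 0: 3
      block_size(in_channels, P_ci, 0))    # worker with index[1] == 0: 3
print(block_size(out_channels, P_co, 3),   # the last workers get the
      block_size(in_channels, P_ci, 2))    # smaller blocks: 2 2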
def __init__(self, P_x, kernel_size, stride=1, padding=0, dilation=1, buffer_manager=None): super(DistributedPoolBase, self).__init__() # P_x is 1 x 1 x P_d-1 x ... x P_0 self.P_x = P_x self.is_avg = self.TorchPoolType in [torch.nn.AvgPool1d, torch.nn.AvgPool2d, torch.nn.AvgPool3d] # Back-end specific buffer manager for economic buffer allocation if buffer_manager is None: buffer_manager = self._distdl_backend.BufferManager() elif type(buffer_manager) is not self._distdl_backend.BufferManager: raise ValueError("Buffer manager type does not match backend.") self.buffer_manager = buffer_manager if not self.P_x.active: return dims = len(self.P_x.shape) self.kernel_size = self._expand_parameter(kernel_size) self.stride = self._expand_parameter(stride) self.padding = self._expand_parameter(padding) self.dilation = self._expand_parameter(dilation) if self.is_avg and not all(x == 1 for x in self.dilation): raise ValueError('dilation is only supported for MaxPooling layers.') # PyTorch does not support dilation for AvgPooling layers if self.is_avg: self.pool_layer = self.TorchPoolType(kernel_size=self.kernel_size, stride=self.stride, padding=0) else: self.pool_layer = self.TorchPoolType(kernel_size=self.kernel_size, stride=self.stride, padding=0, dilation=self.dilation) # We will be using global padding to compute local padding, # so expand it to a numpy array global_padding = np.pad(self.padding, pad_width=(dims-len(self.padding), 0), mode='constant', constant_values=0) self.global_padding = global_padding pad_left_right = self.global_padding.reshape((dims, 1)) + np.zeros((dims, 2), dtype=np.int) self.local_padding = self._compute_local_padding(pad_left_right) # We need to be able to remove some data from the input to the conv # layer. self.needed_slices = None # For the halo layer we also defer construction, so that we can have # the halo shape for the input. The halo will allocate its own # buffers, but it needs this information at construction to be able # to do this in the pre-forward hook. self.halo_layer = None # Variables for tracking input changes and buffer construction self._distdl_is_setup = False self._input_tensor_structure = TensorStructure()
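# Hedged sketch of the `_compute_local_padding` idea used above: implicit
# (global) padding is only applied on the sides of the local subtensor that
# touch the boundary of the global tensor; interior sides are covered by
# halo exchange instead.  The exact DistDL implementation may differ in
# detail.
import numpy as np

def local_padding(pad_left_right, partition_shape, partition_index):
    """pad_left_right: (dims, 2) array of global (left, right) padding."""
    local = np.array(pad_left_right, copy=True)
    for d, (idx, size) in enumerate(zip(partition_index, partition_shape)):
        if idx != 0:
            local[d, 0] = 0          # not on the global left boundary
        if idx != size - 1:
            local[d, 1] = 0          # not on the global right boundary
    return local

# Global padding of 1 on each side of the last axis, 1 x 1 x 3 partition:
pad = np.array([[0, 0], [0, 0], [1, 1]])
print(local_padding(pad, (1, 1, 3), (0, 0, 1)))  # middle worker pads nothing
print(local_padding(pad, (1, 1, 3), (0, 0, 0)))  # left worker pads only left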
def _distdl_module_setup(self, input): r"""Distributed (feature) pooling module setup function. This function is called every time something changes in the input tensor structure. It should not be called manually. Parameters ---------- input : Tuple of forward inputs. See `torch.nn.Module.register_forward_pre_hook` for more details. """ self._distdl_is_setup = True self._input_tensor_structure = TensorStructure(input[0]) if not self.P_x.active: return # To compute the halo regions and interpolation, we need the global # tensor shape. This is not available until when the input is # provided. global_input_tensor_structure = \ self._distdl_backend.assemble_global_tensor_structure(input[0], self.P_x) if self.size is None: global_output_tensor_shape = torch.as_tensor( global_input_tensor_structure.shape).to(torch.float64) global_output_tensor_shape[2:] *= self.scale_factor # I prefer ceil(), torch uses floor(), so we go with floor for consistency global_output_tensor_shape = torch.Size( torch.floor(global_output_tensor_shape).to(torch.int64)) else: if len(self.size) != len(global_input_tensor_structure.shape): raise ValueError( "Provided size does not match input tensor dimension.") global_output_tensor_shape = torch.Size(torch.as_tensor(self.size)) global_output_tensor_structure = TensorStructure() global_output_tensor_structure.shape = global_output_tensor_shape # Using that information, we can get there rest of the halo information exchange_info = self._compute_exchange_info( self.P_x, global_input_tensor_structure, global_output_tensor_structure, self.scale_factor, self.mode, self.align_corners) halo_shape = exchange_info[0] recv_buffer_shape = exchange_info[1] send_buffer_shape = exchange_info[2] needed_ranges = exchange_info[3] self.halo_shape = halo_shape # We can also set up part of the halo layer. self.halo_layer = HaloExchange(self.P_x, halo_shape, recv_buffer_shape, send_buffer_shape, buffer_manager=self.buffer_manager) # We have to select out the "unused" entries. Sometimes there can # be "negative" halos. 
self.needed_slices = assemble_slices(needed_ranges[:, 0], needed_ranges[:, 1]) # TODO #176: This block to compute the start and stop index of the # post-halo exchanged input can be cleaned up, as it is a duplicate of # calculation in the halo layer itself _slice = tuple([slice(i, i + 1) for i in self.P_x.index] + [slice(None)]) x_subtensor_shapes = compute_subtensor_shapes_balanced( global_input_tensor_structure, self.P_x.shape) x_subtensor_start_indices = compute_subtensor_start_indices( x_subtensor_shapes) x_subtensor_stop_indices = compute_subtensor_stop_indices( x_subtensor_shapes) x_start_index = torch.from_numpy( x_subtensor_start_indices[_slice].squeeze()) x_stop_index = torch.from_numpy( x_subtensor_stop_indices[_slice].squeeze()) y_subtensor_shapes = compute_subtensor_shapes_balanced( global_output_tensor_structure, self.P_x.shape) y_subtensor_start_indices = compute_subtensor_start_indices( y_subtensor_shapes) y_subtensor_stop_indices = compute_subtensor_stop_indices( y_subtensor_shapes) y_start_index = torch.from_numpy( y_subtensor_start_indices[_slice].squeeze()) y_stop_index = torch.from_numpy( y_subtensor_stop_indices[_slice].squeeze()) x_start_index = self._compute_needed_start( y_start_index, global_input_tensor_structure.shape, global_output_tensor_structure.shape, self.scale_factor, self.mode, self.align_corners) x_stop_index = self._compute_needed_stop( y_stop_index - 1, global_input_tensor_structure.shape, global_output_tensor_structure.shape, self.scale_factor, self.mode, self.align_corners) self.interp_layer = Interpolate(x_start_index, x_stop_index, global_input_tensor_structure.shape, y_start_index, y_stop_index, global_output_tensor_structure.shape, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners)
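# Hedged sketch of the `_compute_needed_start` / `_compute_needed_stop` idea
# for 1-D linear interpolation; the function name, signature, and exact
# rounding in DistDL may differ.  For an output index j, the source
# coordinate is
#   align_corners=True : j * (n_in - 1) / (n_out - 1)
#   align_corners=False: (j + 0.5) * n_in / n_out - 0.5
# and linear interpolation reads the floor and floor+1 input samples.
import math

def needed_input_range(j_start, j_stop, n_in, n_out, align_corners=False):
    def src(j):
        if align_corners and n_out > 1:
            return j * (n_in - 1) / (n_out - 1)
        return (j + 0.5) * n_in / n_out - 0.5

    lo = math.floor(max(src(j_start), 0))
    hi = min(math.floor(src(j_stop - 1)) + 1, n_in - 1)
    return lo, hi + 1  # half-open range of input indices needed

# A worker owning output indices [8, 16) of a 16-wide output upsampled from
# an 8-wide input needs roughly the right half of the input:
print(needed_input_range(8, 16, n_in=8, n_out=16))  # (3, 8)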
def __init__(self, P_x, P_y, P_w, in_channels, out_channels, kernel_size, stride=1, padding=0, padding_mode='zeros', dilation=1, groups=1, bias=True, *args, **kwargs): super(DistributedChannelConvBase, self).__init__() # P_x is 1 x P_ci x 1 x ... x 1 self.P_x = P_x # P_y is 1 x P_co x 1 x ... x 1 self.P_y = P_y # P_w is P_co x P_ci x 1 x ... x 1 self.P_w = P_w # Even inactive workers need some partition union P_union = self._distdl_backend.Partition() if not (self.P_x.active or self.P_y.active or self.P_w.active): return self.in_channels = in_channels self.out_channels = out_channels self.kernel_size = self._expand_parameter(kernel_size) self.stride = self._expand_parameter(stride) self.padding = self._expand_parameter(padding) self.padding_mode = padding_mode self.dilation = self._expand_parameter(dilation) self.groups = groups self.use_bias = bias # This guarantees that P_union rank 0 has the kernel size, stride, # padding, and dilation factors P_union_temp = P_w.create_partition_union(P_x) P_union = P_union_temp.create_partition_union(P_y) # Ensure that all workers have the full size and structure of P_w P_w_shape = None if P_union.rank == 0: P_w_shape = np.array(P_w.shape, dtype=np.int) P_w_shape = P_union.broadcast_data(P_w_shape, root=0) # Release the temporary resources P_union_temp.deactivate() P_union.deactivate() P_co = P_w_shape[0] P_ci = P_w_shape[1] P_channels = [P_co, P_ci] # Ensure that P_x and P_w are correctly aligned. We also produce a # new P_x that is shaped like 1 x P_ci x 1 x ... x 1, to assist with # broadcasts. P_x_new_shape = [] if self.P_x.active: if (np.any(P_x.shape[2:] != P_w_shape[2:])): raise ValueError( "Spatial components of P_x and P_w must match.") if (np.any(P_x.shape[2:] != np.ones(len(P_x.shape[2:])))): raise ValueError( "Spatial components of P_x must be 1 x ... x 1.") if P_w_shape[1] != P_x.shape[1]: raise ValueError( "Index 2 of P_w dimension must match input channel partition." ) P_x_new_shape = list(P_x.shape) P_x_new_shape.insert(1, 1) # Currently a hack, removing the batch dimension because P_w does # not have one. This is OK because we assume there are no partitions # in the batch dimension. P_x_new_shape = np.asarray(P_x_new_shape[1:], dtype=int) # For the purposes of this layer, we re-cast P_x to have the extra # dimension. This has no impact outside of the layer or on the results. self.P_x = self.P_x.create_cartesian_topology_partition(P_x_new_shape) # Ensure that P_y and P_w are correctly aligned. We also produce a # new P_y that is shaped like P_co x 1 x 1 x ... x 1, to assist with # broadcasts. P_y_new_shape = [] if self.P_y.active: if (np.any(P_y.shape[2:] != P_w_shape[2:])): raise ValueError( "Spatial components of P_y and P_w must match.") if (np.any(P_y.shape[2:] != np.ones(len(P_y.shape[2:])))): raise ValueError( "Spatial components of P_y must be 1 x ... x 1.") if P_w_shape[0] != P_y.shape[1]: raise ValueError( "Index 1 of P_w dimension must match output channel partition." ) P_y_new_shape = list(P_y.shape) P_y_new_shape.insert(2, 1) # Currently a hack, removing the batch dimension because P_w does # not have one. This is OK because we assume there are no partitions # in the batch dimension. P_y_new_shape = np.asarray(P_y_new_shape[1:], dtype=int) # For the purposes of this layer, we re-cast P_x to have the extra # dimension. This has no impact outside of the layer or on the results. 
self.P_y = self.P_y.create_cartesian_topology_partition(P_y_new_shape) self.serial = self.P_w.size == 1 if self.serial: self.conv_layer = self.TorchConvType( in_channels=in_channels, out_channels=out_channels, kernel_size=self.kernel_size, stride=self.stride, padding=self.padding, padding_mode=self.padding_mode, dilation=self.dilation, groups=self.groups, bias=self.use_bias) self.weight = self.conv_layer.weight self.bias = self.conv_layer.bias return # Flag if the global bias is set self.global_bias = bias # Flags if current worker stores (part of) the bias locally. self.stores_bias = False if self.P_w.active: # Let the P_co column store the bias if it is to be used self.stores_bias = self.P_w.index[1] == 0 and self.use_bias # Correct the input arguments based on local properties # This ensures that the in and out channels are correctly shared. local_co, local_ci = compute_subshape(P_channels, P_w.index[0:2], [out_channels, in_channels]) self.conv_layer = self.TorchConvType( in_channels=local_ci, out_channels=local_co, kernel_size=self.kernel_size, stride=self.stride, padding=self.padding, padding_mode=self.padding_mode, dilation=self.dilation, groups=groups, bias=self.stores_bias) # Workers in P_w alias the conv layer to get their weight and perhaps # biases. Every other worker doesn't have a weight or bias. if self.P_w.active: self.weight = self.conv_layer.weight if self.stores_bias: self.bias = self.conv_layer.bias else: if self.use_bias: self.register_buffer('bias', zero_volume_tensor()) else: self.register_buffer('bias', None) else: self.register_buffer('weight', zero_volume_tensor()) if self.use_bias: self.register_buffer('bias', zero_volume_tensor()) else: self.register_buffer('bias', None) # Variables for tracking input changes and buffer construction self._distdl_is_setup = False self._input_tensor_structure = TensorStructure() self.x_broadcast = Broadcast(self.P_x, self.P_w, preserve_batch=True) self.y_sum_reduce = SumReduce(self.P_w, self.P_y, preserve_batch=True)
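# Illustration (plain PyTorch, single process) of why the broadcast /
# local-conv / sum-reduce pattern above reproduces the serial convolution:
# splitting the input channels across "workers", convolving each block with
# the matching weight slab, and summing the partial outputs equals the full
# convolution, with the bias added by exactly one block.
import torch

torch.manual_seed(0)
conv = torch.nn.Conv2d(in_channels=6, out_channels=4, kernel_size=3, bias=True)
x = torch.randn(1, 6, 8, 8)
y_serial = conv(x)

# Two "workers", each holding 3 input channels and the matching weight slab.
y_parallel = 0
for i, sl in enumerate([slice(0, 3), slice(3, 6)]):
    partial = torch.nn.functional.conv2d(
        x[:, sl], conv.weight[:, sl],
        bias=conv.bias if i == 0 else None)  # bias added on one worker only
    y_parallel = y_parallel + partial

print(torch.allclose(y_serial, y_parallel, atol=1e-5))  # True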
def __init__(self, P_x, in_channels, out_channels, kernel_size, stride=1, padding=0, padding_mode='zeros', dilation=1, groups=1, bias=True, buffer_manager=None): super(DistributedFeatureConvBase, self).__init__() # P_x is 1 x 1 x P_d-1 x ... x P_0 self.P_x = P_x # Back-end specific buffer manager for economic buffer allocation if buffer_manager is None: buffer_manager = self._distdl_backend.BufferManager() elif type(buffer_manager) is not self._distdl_backend.BufferManager: raise ValueError("Buffer manager type does not match backend.") self.buffer_manager = buffer_manager if not self.P_x.active: return self.in_channels = in_channels self.out_channels = out_channels self.kernel_size = self._expand_parameter(kernel_size) self.stride = self._expand_parameter(stride) self.padding = self._expand_parameter(padding) self.padding_mode = padding_mode self.dilation = self._expand_parameter(dilation) self.groups = groups self.use_bias = bias self.serial = self.P_x.size == 1 if self.serial: self.conv_layer = self.TorchConvType( in_channels=in_channels, out_channels=out_channels, kernel_size=self.kernel_size, stride=self.stride, padding=self.padding, padding_mode=self.padding_mode, dilation=self.dilation, groups=self.groups, bias=self.use_bias) self.weight = self.conv_layer.weight self.bias = self.conv_layer.bias else: self.conv_layer = self.TorchConvType(in_channels=in_channels, out_channels=out_channels, kernel_size=self.kernel_size, stride=self.stride, padding=0, padding_mode='zeros', dilation=self.dilation, groups=groups, bias=bias) if self.serial: return dims = len(self.P_x.shape) # We will be using global padding to compute local padding, # so expand it to a numpy array global_padding = np.pad(self.padding, pad_width=(dims - len(self.padding), 0), mode='constant', constant_values=0) self.global_padding = global_padding pad_left_right = self.global_padding.reshape((dims, 1)) + np.zeros( (dims, 2), dtype=np.int) self.local_padding = self._compute_local_padding(pad_left_right) # Weights and biases partition P_wb = self.P_x.create_partition_inclusive([0]) self.P_wb_cart = P_wb.create_cartesian_topology_partition([1]) # Release temporary resources P_wb.deactivate() # We want only the root rank of the broadcast to have a weight and a # bias parameter. Every other rank gets a zero-volume tensor. 
if self.P_wb_cart.active: self.weight = torch.nn.Parameter(self.conv_layer.weight.detach()) if self.conv_layer.bias is not None: self.bias = torch.nn.Parameter(self.conv_layer.bias.detach()) else: self.register_buffer('bias', None) else: self.register_buffer('weight', zero_volume_tensor()) if self.conv_layer.bias is not None: self.register_buffer('bias', zero_volume_tensor()) else: self.register_buffer('bias', None) self.weight.requires_grad = self.conv_layer.weight.requires_grad if self.conv_layer.bias is not None: self.bias.requires_grad = self.conv_layer.bias.requires_grad # https://discuss.pytorch.org/t/assign-parameters-to-nn-module-and-have-grad-fn-track-it/62677/2 new_weight = self.conv_layer.weight.detach() * 0 new_weight.requires_grad = self.conv_layer.weight.requires_grad del self.conv_layer.weight self.conv_layer.weight = new_weight if self.conv_layer.bias is not None: new_bias = self.conv_layer.bias.detach() * 0 new_bias.requires_grad = self.conv_layer.bias.requires_grad del self.conv_layer.bias self.conv_layer.bias = new_bias self.w_broadcast = Broadcast(self.P_wb_cart, self.P_x, preserve_batch=False) if self.conv_layer.bias is not None: self.b_broadcast = Broadcast(self.P_wb_cart, self.P_x, preserve_batch=False) # We need to be able to remove some data from the input to the conv # layer. self.needed_slices = None # For the halo layer we also defer construction, so that we can have # the halo shape for the input. The halo will allocate its own # buffers, but it needs this information at construction to be able # to do this in the pre-forward hook. self.halo_layer = None # Variables for tracking input changes and buffer construction self._distdl_is_setup = False self._input_tensor_structure = TensorStructure()
def _distdl_module_setup(self, input): r"""Distributed (feature) convolution module setup function. This function is called every time something changes in the input tensor structure. It should not be called manually. Parameters ---------- input : Tuple of forward inputs. See `torch.nn.Module.register_forward_pre_hook` for more details. """ self._distdl_is_setup = True self._input_tensor_structure = TensorStructure(input[0]) if not self.P_x.active: return if self.serial: return # Compute global and local shapes with padding x_global_structure = \ self._distdl_backend.assemble_global_tensor_structure(input[0], self.P_x) x_local_structure = TensorStructure(input[0]) x_global_shape = x_global_structure.shape x_local_shape = x_local_structure.shape x_global_shape_after_pad = x_global_shape + 2 * self.global_padding x_local_shape_after_pad = x_local_shape + np.sum( self.local_padding, axis=1, keepdims=False) x_local_structure_after_pad = TensorStructure(input[0]) x_local_structure_after_pad.shape = x_local_shape_after_pad # We need to compute the halos with respect to the explicit padding. # So, we assume the padding is already added, then compute the halo regions. compute_subtensor_shapes_unbalanced = \ self._distdl_backend.tensor_decomposition.compute_subtensor_shapes_unbalanced subtensor_shapes = \ compute_subtensor_shapes_unbalanced(x_local_structure_after_pad, self.P_x) # Using that information, we can get there rest of the halo information exchange_info = self._compute_exchange_info( x_global_shape_after_pad, self.kernel_size, self.stride, self._expand_parameter(0), self.dilation, self.P_x.active, self.P_x.shape, self.P_x.index, subtensor_shapes=subtensor_shapes) halo_shape = exchange_info[0] recv_buffer_shape = exchange_info[1] send_buffer_shape = exchange_info[2] needed_ranges = exchange_info[3] self.halo_shape = halo_shape # We can also set up part of the halo layer. self.halo_layer = HaloExchange(self.P_x, halo_shape, recv_buffer_shape, send_buffer_shape, buffer_manager=self.buffer_manager) # We have to select out the "unused" entries. Sometimes there can # be "negative" halos. self.needed_slices = assemble_slices(needed_ranges[:, 0], needed_ranges[:, 1])
class DistributedTranspose(Module): r"""A distributed transpose layer. This class provides the user interface to the transpose distributed data movement primitive. Implementation details are back-end specific. The Transpose algorithm performs a transpose, shuffle, or generalized all-to-all from a tensor partitioned with by `P_x` to a new tensor partitioned with `P_y`. The values of the tensor do not change. Only the distribution of the tensor over the workers changes. If ``P_x`` and ``P_y`` are exactly equal, then no data movement occurs. For input and output tensors that have a batch dimension, the batch dimension needs to be preserved. If a tensor does not have a batch dimension, we should not preserve that for zero-volume outputs. The `preserve_batch` option controls this. Parameters ---------- P_x : Partition of input tensor. P_y : Partition of output tensor. preserve_batch : bool, optional Indicates if batch size should be preserved for zero-volume outputs. buffer_manager : optional External manager for communication buffers """ def __init__(self, P_x, P_y, preserve_batch=True, buffer_manager=None): super(DistributedTranspose, self).__init__() # Global structure of the input tensor, assembled when layer is called self.global_input_tensor_structure = TensorStructure() # Local input and output tensor structures, defined when layer is called self.input_tensor_structure = TensorStructure() self.output_tensor_structure = TensorStructure() # Partition of input tensor. self.P_x = P_x # Partition of output tensor. self.P_y = P_y # Indicates if batch size should be preserved for zero-volume outputs. self.preserve_batch = preserve_batch # List of meta data describing copies of subvolumes of input tensor # out of the current worker self.P_x_to_y_overlaps = [] # List of meta data describing copies of subvolumes of output tensor # into the current worker self.P_y_to_x_overlaps = [] # Back-end specific buffer manager for economic buffer allocation if buffer_manager is None: buffer_manager = self._distdl_backend.BufferManager() elif type(buffer_manager) is not self._distdl_backend.BufferManager: raise ValueError("Buffer manager type does not match backend.") self.buffer_manager = buffer_manager # List of buffers for copying data to other workers self.P_x_to_y_buffers = None # List of buffers for copying data from other workers self.P_y_to_x_buffers = None # Variables for tracking input changes and buffer construction self._distdl_is_setup = False self._input_tensor_structure = TensorStructure() # Otherwise, we need the union of the input and output partitions # so that data can be copied across them. P_union = self._distdl_backend.Partition() if P_x.active or P_y.active: P_union = P_x.create_partition_union(P_y) self.P_union = P_union # Setup these variables incase the current worker is inactive in # the union. self.P_x_shape = None self.P_y_shape = None self.union_indices = None if not P_union.active: return # All active workers need the shapes of both partitions so that buffer # sizes and subtensor overlaps can be computed. data = None if self.P_x.active: data = self.P_x.shape self.P_x_shape = self.P_union.broadcast_data(data, P_data=self.P_x) data = None if self.P_y.active: data = self.P_y.shape self.P_y_shape = self.P_union.broadcast_data(data, P_data=self.P_y) if len(self.P_x_shape) != len(self.P_y_shape): raise ValueError( "Input and output partition must be same dimension.") # Share the two indices with every worker in the union. 
The first # column of data contains the source index and the second contains # the destination index. data = np.array([P_x.rank if P_x.active else -1], dtype=np.int) self.P_x_ranks = P_union.allgather_data(data) data = np.array([P_y.rank if P_y.active else -1], dtype=np.int) self.P_y_ranks = P_union.allgather_data(data) # Get some types and functions from the back-end self.allocate_transpose_buffers = self._distdl_backend.transpose.allocate_transpose_buffers def _distdl_module_setup(self, input): r"""Transpose module setup function. Constructs the necessary buffers and meta information about outbound and inbound copies to each worker. This function is called every time something changes in the input tensor structure. It should not be called manually. Parameters ---------- input : Tuple of forward inputs. See `torch.nn.Module.register_forward_pre_hook` for more details. """ self._distdl_is_setup = True self._input_tensor_structure.fill_from_tensor(input[0]) # If we are not an active worker, do nothing. if not self.P_union.active: return self.input_tensor_structure = TensorStructure(input[0]) self.global_input_tensor_structure = \ self._distdl_backend.assemble_global_tensor_structure(self.input_tensor_structure, self.P_x, self.P_union) x_global_shape = self.global_input_tensor_structure.shape if self.P_y.active: self.output_tensor_structure.shape = compute_subshape( self.P_y.shape, self.P_y.index, x_global_shape) tensor_dim = len(x_global_shape) if len(self.P_x_shape) != tensor_dim: raise ValueError( f"Input partition mush have same dimension " f"({len(self.P_x_shape)}) as input tensor rank ({tensor_dim})." ) if len(self.P_y_shape) != tensor_dim: raise ValueError( f"Output partition mush have same dimension " f"({len(self.P_y_shape)}) as input tensor rank ({tensor_dim})." ) if 1 in x_global_shape[x_global_shape != self.P_y_shape]: raise ValueError( f"Input tensor must not be size 1 " f"({x_global_shape}) in a dimension where " f"output partition is other than 1 ({self.P_y_shape}).") # Get the collective input lengths and origins. This may be load # balanced or it may not be. Therefore we will always assume is is # not load balanced and just build the subshape tensor manually. # This output is needed everywhere so it goes to P_union. compute_subtensor_shapes_unbalanced = \ self._distdl_backend.tensor_decomposition.compute_subtensor_shapes_unbalanced x_subtensor_shapes = compute_subtensor_shapes_unbalanced( self.input_tensor_structure, self.P_x, self.P_union) # Get the collective output lengths and origins. This will always be # load balanced, so we can infer the subshape tensor from the global # tensor shape and the shape of P_y. At this point, every worker in # P_union has both of these pieces of information, so we can build it # with no communication. y_subtensor_shapes = compute_subtensor_shapes_balanced( self.global_input_tensor_structure, self.P_y_shape) # Given all subtensor shapes, we can compute the start and stop indices # for each partition. x_subtensor_start_indices = compute_subtensor_start_indices( x_subtensor_shapes) x_subtensor_stop_indices = compute_subtensor_stop_indices( x_subtensor_shapes) y_subtensor_start_indices = compute_subtensor_start_indices( y_subtensor_shapes) y_subtensor_stop_indices = compute_subtensor_stop_indices( y_subtensor_shapes) # We only need to move data to the output partition if we actually # have input data. It is possible to have both input and output data, # either input or output data, or neither. Hence the active guard. 
        if self.P_x.active:
            x_slice = tuple([slice(i, i + 1) for i in self.P_x.index] + [slice(None)])
            x_start_index = x_subtensor_start_indices[x_slice].squeeze()
            x_stop_index = x_subtensor_stop_indices[x_slice].squeeze()

            # Compute our overlaps for each output subpartition.
            for rank, P_y_index in enumerate(range_index(self.P_y_shape)):

                y_slice = tuple([slice(i, i + 1) for i in P_y_index] + [slice(None)])
                y_start_index = y_subtensor_start_indices[y_slice].squeeze()
                y_stop_index = y_subtensor_stop_indices[y_slice].squeeze()

                sl = compute_subtensor_intersection_slice(
                    x_start_index, x_stop_index, y_start_index, y_stop_index)

                if sl is not None:
                    sh = compute_nd_slice_shape(sl)
                    # If it is a self-copy, mark it so we don't have to create
                    # a potentially large buffer
                    if self.P_y.active and np.all(P_y_index == self.P_y.index):
                        partner = "self"
                    # Otherwise, reverse the mapping to get the output
                    # partner's rank in the common partition.
                    else:
                        partner = np.where(self.P_y_ranks == rank)[0][0]

                    self.P_x_to_y_overlaps.append((sl, sh, partner))
                else:
                    self.P_x_to_y_overlaps.append((None, None, None))

        # We only need to obtain data from the input partition if we actually
        # have output data.
        if self.P_y.active:
            y_slice = tuple([slice(i, i + 1) for i in self.P_y.index] + [slice(None)])
            y_start_index = y_subtensor_start_indices[y_slice].squeeze()
            y_stop_index = y_subtensor_stop_indices[y_slice].squeeze()

            # Compute our overlaps for each input subpartition.
            for rank, P_x_index in enumerate(range_index(self.P_x_shape)):

                x_slice = tuple([slice(i, i + 1) for i in P_x_index] + [slice(None)])
                x_start_index = x_subtensor_start_indices[x_slice].squeeze()
                x_stop_index = x_subtensor_stop_indices[x_slice].squeeze()

                sl = compute_subtensor_intersection_slice(
                    y_start_index, y_stop_index, x_start_index, x_stop_index)

                if sl is not None:
                    sh = compute_nd_slice_shape(sl)
                    # If it is a self-copy, mark it so we don't have to create
                    # a potentially large buffer
                    if self.P_x.active and np.all(P_x_index == self.P_x.index):
                        partner = "self"
                    # Otherwise, reverse the mapping to get the input
                    # partner's rank in the common partition.
                    else:
                        partner = np.where(self.P_x_ranks == rank)[0][0]

                    self.P_y_to_x_overlaps.append((sl, sh, partner))
                else:
                    self.P_y_to_x_overlaps.append((None, None, None))

        buffs = self.allocate_transpose_buffers(self.buffer_manager,
                                                self.P_x_to_y_overlaps,
                                                self.P_y_to_x_overlaps,
                                                self.global_input_tensor_structure.dtype)
        self.P_x_to_y_buffers = buffs[0]
        self.P_y_to_x_buffers = buffs[1]

    def _distdl_module_teardown(self, input):
        r"""Transpose module teardown function.

        Deallocates buffers safely.

        This function is called every time something changes in the input
        tensor structure.  It should not be called manually.

        Parameters
        ----------
        input :
            Tuple of forward inputs.  See
            `torch.nn.Module.register_forward_pre_hook` for more details.

        """

        # Reset all of the buffers and communication objects
        self.P_x_to_y_overlaps = []
        self.P_y_to_x_overlaps = []

        self.P_x_to_y_buffers = None
        self.P_y_to_x_buffers = None

        # Reset any info about the input
        self._distdl_is_setup = False
        self._input_tensor_structure = TensorStructure()

    def _distdl_input_changed(self, input):
        r"""Determine if the structure of inputs has changed.

        Parameters
        ----------
        input :
            Tuple of forward inputs.  See
            `torch.nn.Module.register_forward_pre_hook` for more details.

        """

        new_tensor_structure = TensorStructure(input[0])

        return self._input_tensor_structure != new_tensor_structure

    def forward(self, input):
        """Forward function interface.
        Parameters
        ----------
        input :
            Input tensor to be transposed.

        """

        Function = self._distdl_backend.functional.transpose.DistributedTransposeFunction

        # If this worker is not active for either the input or output
        # partition, no communication is necessary: the input should be a
        # zero-volume tensor and the output is simply a clone of it.
        if not (self.P_x.active or self.P_y.active):
            return input.clone()

        return Function.apply(input,
                              self.P_union,
                              self.global_input_tensor_structure,
                              self.input_tensor_structure,
                              self.output_tensor_structure,
                              self.P_x,
                              self.P_x_to_y_overlaps,
                              self.P_x_to_y_buffers,
                              self.P_y,
                              self.P_y_to_x_overlaps,
                              self.P_y_to_x_buffers,
                              self.preserve_batch)
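
# ---------------------------------------------------------------------------
# Illustrative sketch (not used by the layer): the overlap bookkeeping in
# `_distdl_module_setup` amounts to intersecting the index ranges covered by
# input and output subtensors.  The helpers below are simplified, hypothetical
# stand-ins for the back-end's `compute_subtensor_*` utilities, written only
# to show the idea on a single balanced 1D axis; the real implementations
# handle unbalanced shapes and arbitrary tensor dimensions.
# ---------------------------------------------------------------------------

def _example_balanced_ranges(global_length, n_workers):
    # Split `global_length` indices over `n_workers` as evenly as possible,
    # returning a list of half-open (start, stop) index ranges.
    base, extra = divmod(global_length, n_workers)
    ranges = []
    start = 0
    for worker in range(n_workers):
        length = base + (1 if worker < extra else 0)
        ranges.append((start, start + length))
        start += length
    return ranges


def _example_intersection(range_a, range_b):
    # Intersect two half-open ranges, returning None when they are disjoint,
    # mirroring how an intersection slice of None marks "no data to exchange".
    start = max(range_a[0], range_b[0])
    stop = min(range_a[1], range_b[1])
    return (start, stop) if start < stop else None

# For example, redistributing a length-10 axis from 2 input workers to 3
# output workers yields:
#   _example_balanced_ranges(10, 2)      -> [(0, 5), (5, 10)]
#   _example_balanced_ranges(10, 3)      -> [(0, 4), (4, 7), (7, 10)]
#   _example_intersection((0, 5), (4, 7)) -> (4, 5)
# so input worker 0 sends indices [4, 5) to output worker 1, and so on.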
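
# ---------------------------------------------------------------------------
# Minimal usage sketch (an illustration, not part of the layer): redistribute
# a globally [4, 12] tensor from a 1 x 4 input partition to a 4 x 1 output
# partition on 4 MPI workers.  This assumes DistDL's MPI back-end with its
# `MPIPartition` and `zero_volume_tensor` utilities; module paths and helper
# names may differ between versions, so treat this as a sketch rather than a
# verified recipe.  Run with, e.g., `mpirun -np 4 python this_file.py`.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    import torch
    from mpi4py import MPI

    from distdl.backends.mpi.partition import MPIPartition
    from distdl.utilities.torch import zero_volume_tensor

    P_world = MPIPartition(MPI.COMM_WORLD)

    # Input partition: 1 x 4 (the last axis is split over 4 workers).
    P_x_base = P_world.create_partition_inclusive(np.arange(4))
    P_x = P_x_base.create_cartesian_topology_partition([1, 4])

    # Output partition: 4 x 1 (the first axis is split over 4 workers).
    P_y_base = P_world.create_partition_inclusive(np.arange(4))
    P_y = P_y_base.create_cartesian_topology_partition([4, 1])

    layer = DistributedTranspose(P_x, P_y, preserve_batch=False)

    # Workers outside P_x contribute a zero-volume tensor; active workers each
    # hold an equal [4, 3] block of the global [4, 12] tensor.
    x = zero_volume_tensor()
    if P_x.active:
        x = torch.randn(4, 3)

    # After the transpose, each worker in P_y holds a [1, 12] block.
    y = layer(x)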