Example #1
    def _distdl_module_setup(self, input):
        r"""AllSumReduce module setup function.

        Constructs the necessary partition functions to implement the above
        described reduction pattern.  This function performs collective
        communication across the input and output partitions.

        This function is called every time something changes in the input
        tensor structure.  It should not be called manually.

        Parameters
        ----------
        input :
            Tuple of forward inputs.  See
            `torch.nn.Module.register_forward_pre_hook` for more details.

        """

        if not (self.P_x.active):
            return

        # If it is not an identity, we need actual Partitions to do the work.
        if not self.identity:

            self.P_allreduce = self.P_x.create_allreduction_partition(
                self.axes_reduce)

            self.input_tensor_structure = TensorStructure(input[0])
            self.output_tensor_structure = self.input_tensor_structure

        self._distdl_is_setup = True
        self._input_tensor_structure = TensorStructure(input[0])
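
The setup and teardown methods in these examples are never called manually; they are driven by hooks registered through `torch.nn.Module.register_forward_pre_hook`. The following is a minimal, hypothetical sketch of how such a hook could tie the change-detection, teardown, and setup methods together; it is illustrative only and is not DistDL's actual hook implementation.

def distdl_pre_forward_hook(module, input):
    # If the module was set up for a different input structure, tear it down
    # so it can be rebuilt for the new structure.
    if module._distdl_is_setup and module._distdl_input_changed(input):
        module._distdl_module_teardown(input)

    # (Re)build the partition functions and buffers for this input structure.
    if not module._distdl_is_setup:
        module._distdl_module_setup(input)
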
Example #2
    def _distdl_module_teardown(self, input):
        r"""Broadcast module teardown function.

        Nullifies the necessary partition functions.

        This function is called every time something changes in the input
        tensor structure.  It should not be called manually.

        Parameters
        ----------
        input :
            Tuple of forward inputs.  See
            `torch.nn.Module.register_forward_pre_hook` for more details.

        """

        # Reset all of the buffers and communication objects
        self.P_send.deactivate()
        self.P_recv.deactivate()

        # Reset any data stored about the tensor
        self.input_tensor_structure = TensorStructure()
        self.output_tensor_structure = TensorStructure()

        # Reset any info about the input
        self._distdl_is_setup = False
        self._input_tensor_structure = TensorStructure()
Example #3
    def _distdl_module_teardown(self, input):
        r"""Transpose module teardown function.

        Deallocates buffers safely.

        This function is called every time something changes in the input
        tensor structure.  It should not be called manually.

        Parameters
        ----------
        input :
            Tuple of forward inputs.  See
            `torch.nn.Module.register_forward_pre_hook` for more details.

        """

        # Reset all of the buffers and communication objects
        self.P_x_to_y_overlaps = []
        self.P_y_to_x_overlaps = []

        self.P_x_to_y_buffers = None
        self.P_y_to_x_buffers = None

        # Reset any info about the input
        self._distdl_is_setup = False
        self._input_tensor_structure = TensorStructure()
Example #4
    def __init__(self,
                 P_x,
                 buffer_manager=None,
                 size=None,
                 scale_factor=None,
                 mode='linear',
                 align_corners=False):

        super(DistributedUpsample, self).__init__()

        if mode == 'cubic':
            raise NotImplementedError(
                'Cubic interpolation is not implemented.')

        if size is None and scale_factor is None:
            raise ValueError("One of `size` or `scale_factor` must be set.")

        if size is not None and scale_factor is not None:
            raise ValueError(
                "Only one of `size` or `scale_factor` may be set.")

        # P_x is 1 x 1 x P_d-1 x ... x P_0
        self.P_x = P_x

        # Back-end specific buffer manager for economic buffer allocation
        if buffer_manager is None:
            buffer_manager = self._distdl_backend.BufferManager()
        elif type(buffer_manager) is not self._distdl_backend.BufferManager:
            raise ValueError("Buffer manager type does not match backend.")
        self.buffer_manager = buffer_manager

        if not self.P_x.active:
            return

        # Do this before checking serial so that the layer works properly
        # in the serial case
        # self.pool_layer = self.TorchPoolType(*args, **kwargs)

        self.mode = mode
        self.align_corners = align_corners

        self.size = size
        self.scale_factor = scale_factor

        # Local input and output tensor structures, defined when layer is called
        self.input_tensor_structure = TensorStructure()
        self.output_tensor_structure = TensorStructure()

        # We need the actual sizes to determine the interpolation layer
        self.interp_layer = None

        # For the halo layer we also defer construction, so that we can have
        # the halo shape for the input.  The halo will allocate its own
        # buffers, but it needs this information at construction to be able
        # to do this in the pre-forward hook.
        self.halo_layer = None

        # Variables for tracking input changes and buffer construction
        self._distdl_is_setup = False
        self._input_tensor_structure = TensorStructure()
Example #5
    def __init__(self,
                 P_x,
                 P_y,
                 transpose_src=False,
                 transpose_dest=False,
                 preserve_batch=True):

        super(Broadcast, self).__init__()

        # Partition of input tensor.
        self.P_x = P_x

        # Partition of output tensor.
        self.P_y = P_y

        # Transpose the input partition prior to the broadcast.
        self.transpose_src = transpose_src

        # Transpose the output partition prior to the broadcast.
        self.transpose_dest = transpose_dest

        # Indicates if batch size should be preserved for zero-volume outputs.
        self.preserve_batch = preserve_batch

        # Indicates if broadcast requires any data movement.
        self.identity = False

        # Partition for sharing copy of local data.
        self.P_send = self._distdl_backend.Partition()

        # Partition for receiving copy of local data.
        self.P_recv = self._distdl_backend.Partition()

        # Other info needed by the functions

        # Structure of the input tensor (shape, dtype, requires_grad, etc).
        self.input_tensor_structure = TensorStructure()
        # Structure of the output tensor (shape, dtype, requires_grad, etc).
        self.output_tensor_structure = TensorStructure()

        # Variables for tracking input changes and buffer construction
        self._distdl_is_setup = False
        self._input_tensor_structure = TensorStructure()

        # The identity case occurs when the partitions are of size 1,
        # or they are the same partition and neither is transposed,
        # or they are the same partition and both are transposed.
        if self.P_x == self.P_y:
            if self.P_x.size == 1:
                self.identity = True
            elif (self.transpose_dest and self.transpose_src) or \
                 (not self.transpose_dest and not self.transpose_src):
                self.identity = True
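
The identity logic above can be restated as a small predicate on the two partitions and the transpose flags. The sketch below is purely illustrative; `same_partition` and `size` stand in for the partition equality and size checks used in the constructor.

def broadcast_is_identity(same_partition, size, transpose_src, transpose_dest):
    # A broadcast is a no-op when the input and output partitions are the
    # same and either that partition has size 1 or the transpose flags agree.
    if not same_partition:
        return False
    return size == 1 or (transpose_src == transpose_dest)

# broadcast_is_identity(True, 4, False, False) -> True
# broadcast_is_identity(True, 4, True, False)  -> False
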
Example #6
    def __init__(self,
                 P_x,
                 halo_shape,
                 recv_buffer_shape,
                 send_buffer_shape,
                 buffer_manager=None):

        super(HaloExchange, self).__init__()

        self.P_x = P_x
        self.halo_shape = halo_shape
        self.recv_buffer_shape = recv_buffer_shape
        self.send_buffer_shape = send_buffer_shape

        self.neighbor_ranks = None
        if self.P_x.active:
            self.neighbor_ranks = self.P_x.neighbor_ranks(self.P_x.rank)

        self.slices = None
        self.buffers = None

        # Back-end specific buffer manager for economic buffer allocation
        if buffer_manager is None:
            buffer_manager = self._distdl_backend.BufferManager()
        elif type(buffer_manager) is not self._distdl_backend.BufferManager:
            raise ValueError("Buffer manager type does not match backend.")
        self.buffer_manager = buffer_manager

        # Variables for tracking input changes and buffer construction
        self._distdl_is_setup = False
        self._input_tensor_structure = TensorStructure()

        # Get some types and functions from the back-end
        self.allocate_halo_exchange_buffers = self._distdl_backend.halo_exchange.allocate_halo_exchange_buffers
Example #7
File: loss.py  Project: rhewett/distdl
    def _distdl_module_setup(self, input):
        r"""Distributed loss setup function.

        This function is called every time something changes in the input
        tensor structure.  It should not be called manually.

        Parameters
        ----------
        input :
            Tuple of forward inputs.  See
            `torch.nn.Module.register_forward_pre_hook` for more details.

        """

        if not self.P_x.active:
            return

        # If the reduction mode is "mean", we need the total size of the
        # global input tensor.
        if self.reduction == "mean":
            N_local = np.prod(input[0].shape)
            self.normalization_factor = int(
                self.P_x.allreduce_data(np.asarray(N_local)))

        self._distdl_is_setup = True
        self._input_tensor_structure = TensorStructure(input[0])
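
For the "mean" reduction, the normalization factor is just the global element count: each worker contributes the number of elements in its local input, and the all-reduce sums those counts. A plain-Python illustration (no MPI, hypothetical local shapes):

import numpy as np

# Local input shapes held by four hypothetical workers.
local_shapes = [(2, 5, 5), (2, 5, 5), (2, 5, 3), (2, 5, 3)]

# The all-reduce above computes the sum of the local element counts.
normalization_factor = int(sum(np.prod(s) for s in local_shapes))
print(normalization_factor)  # 160, the number of elements in the global tensor
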
Example #8
File: loss.py  Project: rhewett/distdl
    def _distdl_module_setup(self, input):
        r"""Distributed KL Divergence loss setup function.

        This function is called every time something changes in the input
        tensor structure.  It should not be called manually.

        Parameters
        ----------
        input :
            Tuple of forward inputs.  See
            `torch.nn.Module.register_forward_pre_hook` for more details.

        """

        if not self.P_x.active:
            return

        # If the reduction mode is "mean", we need the total size of the
        # global input tensor.
        if self.reduction == "mean":
            N_local = np.prod(input[0].shape)
            self.normalization_factor = int(
                self.P_x.allreduce_data(np.asarray(N_local)))
        # If the reduction mode is "batchmean", we need the batch size, which
        # means we only add the shapes along the batch axis.
        elif self.reduction == "batchmean":
            if all(coord == 0 for coord in self.P_x.index[1:]):
                N_local = input[0].shape[0]
            else:
                N_local = 0
            self.normalization_factor = int(
                self.P_x.allreduce_data(np.asarray(N_local)))

        self._distdl_is_setup = True
        self._input_tensor_structure = TensorStructure(input[0])
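
For "batchmean", only workers whose partition index is zero in every non-batch dimension contribute their local batch size, so each batch entry is counted exactly once. A short illustration (no MPI, hypothetical 2 x 2 partition with the global batch split 5 + 3):

workers = [
    {"index": (0, 0), "local_batch": 5},
    {"index": (0, 1), "local_batch": 5},
    {"index": (1, 0), "local_batch": 3},
    {"index": (1, 1), "local_batch": 3},
]

# Only workers with index 0 along the non-batch axes contribute.
normalization_factor = sum(w["local_batch"] for w in workers
                           if all(c == 0 for c in w["index"][1:]))
print(normalization_factor)  # 8, the global batch size
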
Example #9
    def _distdl_module_teardown(self, input):

        # Reset all of the buffers and communication objects
        self.slices = None
        self.buffers = None

        # Reset any info about the input
        self._distdl_is_setup = False
        self._input_tensor_structure = TensorStructure()
Example #10
    def __init__(self, P_x, axes_reduce=None, axes_keep=None):

        super(AllSumReduce, self).__init__()

        # Partition of input and output tensor.
        self.P_x = P_x

        # Partition dimensions along which the all-reduction takes place.
        # While we compute both terms, `axes_reduce` is used internally.
        if axes_reduce is None and axes_keep is None:
            raise ValueError(
                "One of `axes_reduce` or `axes_keep` must be specified.")
        elif axes_reduce is not None and axes_keep is not None:
            raise ValueError(
                "Only one of `axes_reduce` or `axes_keep` may be specified.")
        elif axes_reduce is not None:
            self.axes_reduce = axes_reduce
            self.axes_keep = [
                d for d in range(P_x.dim) if d not in axes_reduce
            ]
        elif axes_keep is not None:
            self.axes_reduce = [
                d for d in range(P_x.dim) if d not in axes_keep
            ]
            self.axes_keep = axes_keep

        # Indicates if broadcast requires any data movement.
        self.identity = False

        # Partition for performing all-reduction.
        self.P_allreduce = self._distdl_backend.Partition()

        # Structure of the input tensor (shape, dtype, requires_grad, etc).
        self.input_tensor_structure = TensorStructure()
        # Structure of the output tensor (shape, dtype, requires_grad, etc).
        self.output_tensor_structure = TensorStructure()

        # Variables for tracking input changes and buffer construction
        self._distdl_is_setup = False
        self._input_tensor_structure = TensorStructure()

        # The identity case occurs when the partition has size 1.
        if self.P_x.size == 1:
            self.identity = True
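
The `axes_reduce` and `axes_keep` arguments are complements of each other over the partition dimensions, as the constructor computes above. A quick illustration for a hypothetical 4-dimensional partition:

dim = 4
axes_keep = [0, 1]
axes_reduce = [d for d in range(dim) if d not in axes_keep]
print(axes_reduce)  # [2, 3]
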
Example #11
def assemble_global_tensor_structure(local_tensor_structure, P_in, P_out=None):

    global_tensor_structure = TensorStructure()
    global_tensor_shape = None
    intID_dtype = None
    requires_grad_int = None

    if P_in.active:

        # Assemble the global shape
        global_tensor_shape = np.zeros(P_in.dim, dtype=np.int)
        for i in range(P_in.dim):

            keep = [False] * P_in.dim
            keep[i] = True

            P_sub = P_in.create_cartesian_subtopology_partition(keep)

            v0 = np.atleast_1d(int(local_tensor_structure.shape[i]))
            v1 = np.zeros(1, dtype=np.int)
            P_sub._comm.Allreduce(v0, v1, op=MPI.SUM)
            global_tensor_shape[i] = v1[0]

            # Free the subtopology resources
            P_sub.deactivate()

        # Get a communicable integer representing the dtype
        intID_dtype = torch_to_intID_dtype_dict[local_tensor_structure.dtype]
        intID_dtype = np.array([intID_dtype], dtype=np.int)

        requires_grad_int = np.array([-1], dtype=np.int)
        requires_grad_int[0] = 1 if local_tensor_structure.requires_grad else 0

        global_tensor_structure.shape = global_tensor_shape
        global_tensor_structure.dtype = local_tensor_structure.dtype
        global_tensor_structure.requires_grad = local_tensor_structure.requires_grad

    if P_out is not None and P_out.active:
        # Share the shape
        global_tensor_structure.shape = P_out.broadcast_data(
            global_tensor_shape, P_data=P_in)

        # Share the dtype
        intID_dtype = P_out.broadcast_data(intID_dtype, P_data=P_in)
        global_tensor_structure.dtype = intID_to_torch_dtype_dict[
            intID_dtype[0]]

        # Share the requires_grad status
        requires_grad_int = P_out.broadcast_data(requires_grad_int,
                                                 P_data=P_in)
        global_tensor_structure.requires_grad = bool(requires_grad_int[0])

    return global_tensor_structure
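
The per-dimension loop above recovers the global shape by summing local extents over a one-dimensional subtopology of workers: for dimension i, only the workers that vary along axis i (with all other indices fixed) are summed. A plain-NumPy illustration without MPI, for a hypothetical 2 x 3 worker grid holding blocks of a 5 x 11 global tensor:

import numpy as np

# local_shapes[i, j] is the local (rows, cols) shape held by worker (i, j).
local_shapes = np.array([[(3, 4), (3, 4), (3, 3)],
                         [(2, 4), (2, 4), (2, 3)]])

# Dimension 0: sum the row extents down one column of workers.
# Dimension 1: sum the column extents across one row of workers.
global_shape = (int(local_shapes[:, 0, 0].sum()),
                int(local_shapes[0, :, 1].sum()))
print(global_shape)  # (5, 11)
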
Example #12
    def _distdl_module_setup(self, input):

        if self.P_x.active:
            x_local_shape = input[0].shape
            self.slices = self._assemble_slices(x_local_shape,
                                                self.recv_buffer_shape,
                                                self.send_buffer_shape)
            self.buffers = self.allocate_halo_exchange_buffers(
                self.buffer_manager, self.slices, self.recv_buffer_shape,
                self.send_buffer_shape, input[0].dtype)

        self._distdl_is_setup = True
        self._input_tensor_structure = TensorStructure(input[0])
Example #13
    def _distdl_input_changed(self, input):
        r"""Determine if the structure of inputs has changed.

        Parameters
        ----------
        input :
            Tuple of forward inputs.  See
            `torch.nn.Module.register_forward_pre_hook` for more details.

        """

        new_tensor_structure = TensorStructure(input[0])

        return self._input_tensor_structure != new_tensor_structure
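
Change detection reduces to comparing a recorded tensor structure against the structure of the new input. The sketch below uses an illustrative stand-in that records the shape, dtype, and requires_grad flag (the properties mentioned in comments elsewhere in these examples); it is not DistDL's actual TensorStructure class.

from dataclasses import dataclass

import torch

@dataclass
class ToyTensorStructure:
    shape: tuple = None
    dtype: object = None
    requires_grad: bool = None

    @classmethod
    def of(cls, tensor):
        return cls(tuple(tensor.shape), tensor.dtype, tensor.requires_grad)

old = ToyTensorStructure.of(torch.zeros(2, 3))
new = ToyTensorStructure.of(torch.zeros(2, 4))
print(old != new)  # True: the shape changed, so setup would be re-run
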
Example #14
    def _distdl_module_setup(self, input):
        r"""Broadcast module setup function.

        Constructs the necessary partition functions to implement the above
        described broadcast pattern.  This function performs collective
        communication across the input and output partitions.

        This function is called every time something changes in the input
        tensor structure.  It should not be called manually.

        Parameters
        ----------
        input :
            Tuple of forward inputs.  See
            `torch.nn.Module.register_forward_pre_hook` for more details.

        """

        if not (self.P_x.active or self.P_y.active):
            return

        # If it is not an identity, we need actual Partitions to do the work.
        if not self.identity:
            bcast_partitions = self.P_x.create_broadcast_partition_to(
                self.P_y, self.transpose_src, self.transpose_dest)
            self.P_send = bcast_partitions[0]
            self.P_recv = bcast_partitions[1]

            self.input_tensor_structure = TensorStructure(input[0])
            self.output_tensor_structure = \
                self._distdl_backend.broadcast_tensor_structure(self.input_tensor_structure,
                                                                self.P_send,
                                                                self.P_recv)

        self._distdl_is_setup = True
        self._input_tensor_structure = TensorStructure(input[0])
Example #15
    def _distdl_module_teardown(self, input):
        r"""Distributed (channel) convolution module teardown function.

        This function is called every time something changes in the input
        tensor structure.  It should not be called manually.

        Parameters
        ----------
        input :
            Tuple of forward inputs.  See
            `torch.nn.Module.register_forward_pre_hook` for more details.

        """

        # Reset any info about the input
        self._distdl_is_setup = False
        self._input_tensor_structure = TensorStructure()
Example #16
File: loss.py  Project: rhewett/distdl
    def _distdl_module_teardown(self, input):
        r"""Distributed loss teardown function.

        Nullifies the necessary partition functions.

        This function is called every time something changes in the input
        tensor structure.  It should not be called manually.

        Parameters
        ----------
        input :
            Tuple of forward inputs.  See
            `torch.nn.Module.register_forward_pre_hook` for more details.

        """

        self.normalization_factor = 1

        # Reset any info about the input
        self._distdl_is_setup = False
        self._input_tensor_structure = TensorStructure()
Example #17
    def _distdl_module_setup(self, input):
        r"""Distributed (channel) convolution module setup function.

        This function is called every time something changes in the input
        tensor structure.  It should not be called manually.

        Parameters
        ----------
        input :
            Tuple of forward inputs.  See
            `torch.nn.Module.register_forward_pre_hook` for more details.

        """

        # No setup is needed if the worker is not doing anything for this
        # layer.
        if not (self.P_x.active or self.P_y.active or self.P_w.active):
            return

        if self.serial:
            return

        self._distdl_is_setup = True
        self._input_tensor_structure = TensorStructure(input[0])
Example #18
    def __init__(self, P_x, P_y, preserve_batch=True, buffer_manager=None):
        super(DistributedTranspose, self).__init__()

        # Global structure of the input tensor, assembled when layer is called
        self.global_input_tensor_structure = TensorStructure()

        # Local input and output tensor structures, defined when layer is called
        self.input_tensor_structure = TensorStructure()
        self.output_tensor_structure = TensorStructure()

        # Partition of input tensor.
        self.P_x = P_x

        # Partition of output tensor.
        self.P_y = P_y

        # Indicates if batch size should be preserved for zero-volume outputs.
        self.preserve_batch = preserve_batch

        # List of meta data describing copies of subvolumes of input tensor
        # out of the current worker
        self.P_x_to_y_overlaps = []

        # List of meta data describing copies of subvolumes of output tensor
        # into the current worker
        self.P_y_to_x_overlaps = []

        # Back-end specific buffer manager for economic buffer allocation
        if buffer_manager is None:
            buffer_manager = self._distdl_backend.BufferManager()
        elif type(buffer_manager) is not self._distdl_backend.BufferManager:
            raise ValueError("Buffer manager type does not match backend.")
        self.buffer_manager = buffer_manager

        # List of buffers for copying data to other workers
        self.P_x_to_y_buffers = None

        # List of buffers for copying data from other workers
        self.P_y_to_x_buffers = None

        # Variables for tracking input changes and buffer construction
        self._distdl_is_setup = False
        self._input_tensor_structure = TensorStructure()

        # We need the union of the input and output partitions
        # so that data can be copied across them.
        P_union = self._distdl_backend.Partition()
        if P_x.active or P_y.active:
            P_union = P_x.create_partition_union(P_y)
        self.P_union = P_union

        # Set up these variables in case the current worker is inactive in
        # the union.
        self.P_x_shape = None
        self.P_y_shape = None
        self.union_indices = None

        if not P_union.active:
            return

        # All active workers need the shapes of both partitions so that buffer
        # sizes and subtensor overlaps can be computed.
        data = None
        if self.P_x.active:
            data = self.P_x.shape
        self.P_x_shape = self.P_union.broadcast_data(data, P_data=self.P_x)

        data = None
        if self.P_y.active:
            data = self.P_y.shape
        self.P_y_shape = self.P_union.broadcast_data(data, P_data=self.P_y)

        if len(self.P_x_shape) != len(self.P_y_shape):
            raise ValueError(
                "Input and output partition must be same dimension.")

        # Share the ranks with every worker in the union: each worker
        # contributes its rank in P_x and its rank in P_y (or -1 if it is
        # inactive in that partition).
        data = np.array([P_x.rank if P_x.active else -1], dtype=np.int)
        self.P_x_ranks = P_union.allgather_data(data)

        data = np.array([P_y.rank if P_y.active else -1], dtype=np.int)
        self.P_y_ranks = P_union.allgather_data(data)

        # Get some types and functions from the back-end
        self.allocate_transpose_buffers = self._distdl_backend.transpose.allocate_transpose_buffers
Example #19
def broadcast_tensor_structure(input_tensor_structure, P_send, P_recv):

    output_tensor_structure = TensorStructure()

    if not P_send.active and not P_recv.active:
        return output_tensor_structure

    requests = []

    if P_send.active:
        # Share the torch dtype code, converted to an int.
        intID_dtype = torch_to_intID_dtype_dict[input_tensor_structure.dtype]
        send_intID_dtype = np.array([intID_dtype], dtype=np.int)
        req = P_send._comm.Iallreduce(MPI.IN_PLACE,
                                      send_intID_dtype,
                                      op=MPI.MAX)
        requests.append(req)

        # Need to send non-Python types, so convert the boolean temporarily
        rg_int_send = np.array([-1], dtype=np.int)
        rg_int_send[0] = 1 if input_tensor_structure.requires_grad else 0
        req = P_send._comm.Iallreduce(MPI.IN_PLACE, rg_int_send, op=MPI.MAX)
        requests.append(req)

        # Sending processes know the shape, so they can send a copy of the
        # data.  We will ignore this variable later.
        send_tensor_dim = np.array([len(input_tensor_structure.shape)],
                                   dtype=np.int)
        req = P_send._comm.Iallreduce(MPI.IN_PLACE,
                                      send_tensor_dim,
                                      op=MPI.MAX)
        requests.append(req)

        # Similarly, sending processes know the tensor shape, so they can send
        # a copy of it, but we will not use that copy for our actual return
        # value.
        send_tensor_shape = np.array(input_tensor_structure.shape,
                                     dtype=np.int)
        req = P_send._comm.Iallreduce(MPI.IN_PLACE,
                                      send_tensor_shape,
                                      op=MPI.MAX)
        requests.append(req)

    # If the process is a receiving process that does not already know the
    # data (because it is not also the sending process), then we receive the
    # results.  If it is a receiving process that also sent data to a
    # different set of processes, we still have to complete the receive, even
    # though we will not use that data later.
    if (P_send != P_recv) and P_recv.active:

        # Everyone needs to receive these two values, but we don't need them
        # for future communication in this function so we can defer receiving
        # the data.
        recv_intID_dtype = np.array([-1], dtype=np.int)
        req = P_recv._comm.Iallreduce(MPI.IN_PLACE,
                                      recv_intID_dtype,
                                      op=MPI.MAX)
        requests.append(req)

        rg_int_recv = np.array([-1], dtype=np.int)
        req = P_recv._comm.Iallreduce(MPI.IN_PLACE, rg_int_recv, op=MPI.MAX)
        requests.append(req)

        # We need this value for the next communication, so we have to wait
        # for it to complete before moving on.
        recv_tensor_dim = np.array([-1], dtype=np.int)
        req = P_recv._comm.Iallreduce(MPI.IN_PLACE,
                                      recv_tensor_dim,
                                      op=MPI.MAX)
        req.Wait()

        recv_tensor_shape = np.zeros(recv_tensor_dim, dtype=np.int)
        recv_tensor_shape[:] = -1
        req = P_recv._comm.Iallreduce(MPI.IN_PLACE,
                                      recv_tensor_shape,
                                      op=MPI.MAX)
        requests.append(req)

    # Make sure all requests, including the final receive all-reduce, complete
    # before receiving processes can actually copy the data out.
    MPI.Request.Waitall(requests)

    # Wait until the communication is complete to set these values.  Only
    # receiving ranks that do not have the data originally should enter here.
    if P_recv.active and (P_send != P_recv):
        output_tensor_structure.shape = torch.Size(recv_tensor_shape)
        output_tensor_structure.dtype = intID_to_torch_dtype_dict[
            recv_intID_dtype[0]]
        output_tensor_structure.requires_grad = bool(rg_int_recv[0])

    elif P_send == P_recv:
        output_tensor_structure.shape = input_tensor_structure.shape
        output_tensor_structure.dtype = input_tensor_structure.dtype
        output_tensor_structure.requires_grad = input_tensor_structure.requires_grad

    # Finally, every active worker should have valid data.  Any sending rank
    # created it from input data.  Any receiving-only rank used what it was
    # given.
    return output_tensor_structure
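
The communication pattern above relies on in-place, non-blocking all-reduces that are completed with waits. The following is a minimal, standalone mpi4py sketch of that pattern (run with, e.g., `mpiexec -n 4 python demo.py`; the file name is hypothetical). Each rank contributes one value, and MPI.MAX leaves every rank holding the largest, mirroring how the dtype code, requires_grad flag, and shape are shared above.

from mpi4py import MPI
import numpy as np

comm = MPI.COMM_WORLD

# Each rank contributes its own rank number, reduced in place with MAX.
value = np.array([comm.rank], dtype=np.int64)
req = comm.Iallreduce(MPI.IN_PLACE, value, op=MPI.MAX)

# Wait for the non-blocking collective to complete before using the result.
MPI.Request.Waitall([req])

print(comm.rank, value[0])  # every rank prints the largest rank number
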
Example #20
    def __init__(self, P_x, P_y, P_w,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 padding_mode='zeros',
                 dilation=1,
                 groups=1,
                 bias=True,
                 buffer_manager=None):

        super(DistributedGeneralConvBase, self).__init__()

        # P_x is 1    x P_ci x P_d-1 x ... x P_0
        self.P_x = P_x
        # P_y is 1    x P_co x P_d-1 x ... x P_0
        self.P_y = P_y
        # P_w is P_co x P_ci x P_d-1 x ... x P_0
        self.P_w = P_w

        # Back-end specific buffer manager for economic buffer allocation
        if buffer_manager is None:
            buffer_manager = self._distdl_backend.BufferManager()
        elif type(buffer_manager) is not self._distdl_backend.BufferManager:
            raise ValueError("Buffer manager type does not match backend.")
        self.buffer_manager = buffer_manager

        # Even inactive workers need some partition union
        self.P_union = self._distdl_backend.Partition()
        if not (self.P_x.active or
                self.P_y.active or
                self.P_w.active):
            return

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = self._expand_parameter(kernel_size)
        self.stride = self._expand_parameter(stride)
        self.padding = self._expand_parameter(padding)
        self.padding_mode = padding_mode
        self.dilation = self._expand_parameter(dilation)
        self.groups = groups
        self.use_bias = bias

        # This guarantees that P_union rank 0 has the kernel size, stride,
        # padding, and dilation factors
        P_union_temp = P_w.create_partition_union(P_x)
        self.P_union = P_union_temp.create_partition_union(P_y)

        # Release the temporary resources
        P_union_temp.deactivate()

        # Ensure that all workers have the full size and structure of P_w
        P_w_shape = None
        if self.P_union.rank == 0:
            P_w_shape = np.array(P_w.shape, dtype=np.int)
        P_w_shape = self.P_union.broadcast_data(P_w_shape, root=0)

        P_co = P_w_shape[0]
        P_ci = P_w_shape[1]
        P_channels = [P_co, P_ci]

        # Ensure that P_x and P_w are correctly aligned.  We also produce a
        # new P_x that is shaped like 1 x P_ci x P_d-1 x ... x P_0, to assist
        # with broadcasts.
        P_x_new_shape = []
        if self.P_x.active:
            if(np.any(P_x.shape[2:] != P_w_shape[2:])):
                raise ValueError("Spatial components of P_x and P_w must match.")
            if P_w_shape[1] != P_x.shape[1]:
                raise ValueError("Index 2 of P_w dimension must match input channel partition.")
            P_x_new_shape = list(P_x.shape)
            P_x_new_shape.insert(1, 1)
            # Currently a hack, removing the batch dimension because P_w does
            # not have one. This is OK because we assume there are no partitions
            # in the batch dimension.
            P_x_new_shape = np.asarray(P_x_new_shape[1:], dtype=int)

        # For the purposes of this layer, we re-cast P_x to have the extra
        # dimension.  This has no impact outside of the layer or on the results.
        self.P_x = self.P_x.create_cartesian_topology_partition(P_x_new_shape)

        # Ensure that P_y and P_w are correctly aligned.  We also produce a
        # new P_y that is shaped like P_co x 1 x P_d-1 x ... x P_0, to assist
        # with broadcasts.
        P_y_new_shape = []
        if self.P_y.active:
            if(np.any(P_y.shape[2:] != P_w_shape[2:])):
                raise ValueError("Spatial components of P_y and P_w must match.")
            if P_w_shape[0] != P_y.shape[1]:
                raise ValueError("Index 1 of P_w dimension must match output channel partition.")
            P_y_new_shape = list(P_y.shape)
            P_y_new_shape.insert(2, 1)
            # Currently a hack, removing the batch dimension because P_w does
            # not have one. This is OK because we assume there are no partitions
            # in the batch dimension.
            P_y_new_shape = np.asarray(P_y_new_shape[1:], dtype=int)

        # For the purposes of this layer, we re-cast P_y to have the extra
        # dimension.  This has no impact outside of the layer or on the results.
        self.P_y = self.P_y.create_cartesian_topology_partition(P_y_new_shape)

        P_spatial = P_w_shape[2:]

        self.serial = self.P_w.size == 1

        if self.serial:
            self.conv_layer = self.TorchConvType(in_channels=in_channels,
                                                 out_channels=out_channels,
                                                 kernel_size=self.kernel_size,
                                                 stride=self.stride,
                                                 padding=self.padding,
                                                 padding_mode=self.padding_mode,
                                                 dilation=self.dilation,
                                                 groups=self.groups,
                                                 bias=self.use_bias)
            self.weight = self.conv_layer.weight
            self.bias = self.conv_layer.bias
            return

        # We need to figure out the local padding necessary to implement the
        # global padding.  This applies only to the input tensor.  The
        # convolution will not use any implicit padding, so the work partition
        # does not need it.
        if self.P_x.active:
            dims = len(self.P_x.shape)

            # We will be using global padding to compute local padding,
            # so expand it to a numpy array
            global_padding = np.pad(self.padding,
                                    pad_width=(dims-len(self.padding), 0),
                                    mode='constant',
                                    constant_values=0)
            self.global_padding = global_padding

            pad_left_right = self.global_padding.reshape((dims, 1)) + np.zeros((dims, 2), dtype=np.int)
            self.local_padding = self._compute_local_padding(pad_left_right)

        # Workers can either store the learnable weights and bias, or they
        # need copies of it.
        self.receives_weight = False
        self.stores_weight = False
        self.receives_bias = False
        self.stores_bias = False

        # Determine root partitions, initialize weights there
        if self.P_w.active:
            # All of P_w always receives the weight
            self.receives_weight = True

            # This subset is taken to be the origin of the spatial component
            w_root_subset = []
            for i, c in enumerate(range_index(P_w.shape)):
                c = np.asarray(c)
                # Find the P_co x P_ci x 1 x ... x 1 subset to store the weights
                if np.all(c[2:] == 0):
                    w_root_subset.append(i)

            P_wr_base = self.P_w.create_partition_inclusive(w_root_subset)
            # ones are needed so the broadcast will work
            self.P_wr = P_wr_base.create_cartesian_topology_partition([P_co, P_ci] + [1]*len(P_spatial))
            self.stores_weight = self.P_wr.active

            # Release temporary resources
            P_wr_base.deactivate()

            b_subset = []
            for i, c in enumerate(range_index(P_w.shape)):
                c = np.asarray(c)
                # Find the P_co x 1 x P_0 x ... x P_D-1 subset that needs
                # biases in its calculation.  This is everywhere the
                # input-channel index is 0.
                if c[1] == 0:
                    b_subset.append(i)

            P_b_base = self.P_w.create_partition_inclusive(b_subset)
            self.P_b = P_b_base.create_cartesian_topology_partition([P_co] + [1] + list(P_spatial))
            self.receives_bias = self.P_b.active and self.use_bias

            # Release temporary resources
            P_b_base.deactivate()

            # Now find the subset of _that_ which actually stores the
            # learnable parameter.
            b_root_subset = []
            for i, c in enumerate(range_index(P_w.shape)):
                c = np.asarray(c)
                # Find the P_co x 1 x 1 x ... x 1 subset to store the biases
                if np.all(c[1:] == 0):
                    b_root_subset.append(i)

            P_br_base = self.P_w.create_partition_inclusive(b_root_subset)
            # ones are needed so the broadcast will work
            self.P_br = P_br_base.create_cartesian_topology_partition([P_co] + [1] + [1]*len(P_spatial))
            self.stores_bias = self.P_br.active and self.use_bias

            # Release temporary resources
            P_br_base.deactivate()

            # Correct the input arguments based on local properties
            # This ensures that the in and out channels are correctly shared.
            local_co, local_ci = compute_subshape(P_channels,
                                                  P_w.index[0:2],
                                                  [out_channels, in_channels])
            self.conv_layer = self.TorchConvType(in_channels=local_ci,
                                                 out_channels=local_co,
                                                 kernel_size=self.kernel_size,
                                                 stride=self.stride,
                                                 padding=0,
                                                 padding_mode='zeros',
                                                 dilation=self.dilation,
                                                 groups=groups,
                                                 bias=self.receives_bias)

            # If we store the weight it is a learnable parameter iff it is
            # learnable by default in the layer, which it is.
            if self.stores_weight:
                self.weight = torch.nn.Parameter(self.conv_layer.weight.detach())
            else:
                self.register_buffer('weight', zero_volume_tensor())
            # This always exists so we can copy the property
            self.weight.requires_grad = self.conv_layer.weight.requires_grad

            # https://discuss.pytorch.org/t/assign-parameters-to-nn-module-and-have-grad-fn-track-it/62677/2
            new_weight = self.conv_layer.weight.detach() * 0
            new_weight.requires_grad = self.conv_layer.weight.requires_grad
            del self.conv_layer.weight
            self.conv_layer.weight = new_weight

            # If we store the bias, it is a learnable parameter iff it is
            # learnable by default in the layer, which is only true if it
            # exists.
            if self.stores_bias:
                self.bias = torch.nn.Parameter(self.conv_layer.bias.detach())
            else:
                if self.use_bias:
                    self.register_buffer('bias', zero_volume_tensor())
                else:
                    self.register_buffer('bias', None)
            # This does not always exist, but when it does we can copy the
            # property.
            if self.receives_bias:
                self.bias.requires_grad = self.conv_layer.bias.requires_grad

                # https://discuss.pytorch.org/t/assign-parameters-to-nn-module-and-have-grad-fn-track-it/62677/2
                new_bias = self.conv_layer.bias.detach() * 0
                new_bias.requires_grad = self.conv_layer.bias.requires_grad
                del self.conv_layer.bias
                self.conv_layer.bias = new_bias

        else:
            # Workers not in P_w don't have a weight or bias.
            self.register_buffer('weight', zero_volume_tensor())
            if self.use_bias:
                self.register_buffer('bias', zero_volume_tensor())
            else:
                self.register_buffer('bias', None)

        # Now we need to share the kernel structure.  The size of the kernel
        # is always the spatial dimensions.
        self.conv_kernel_size = None
        self.conv_stride = None
        self.conv_padding = None
        self.conv_dilation = None

        # By construction, rank 0 of the union should always have all of this
        # information, because it will always construct a local conv layer. We
        # rely on the local conv layer to properly fill out this information
        # from the defaults.  This info is required for all workers on the
        # input and output partitions because it is needed to construct the
        # halos.  Rank 0 in the union shares it with everyone.
        if self.P_union.rank == 0:
            self.conv_kernel_size = np.array(self.conv_layer.kernel_size, dtype=np.int)
            self.conv_stride = np.array(self.conv_layer.stride, dtype=np.int)
            self.conv_padding = np.array(self.conv_layer.padding, dtype=np.int)
            self.conv_dilation = np.array(self.conv_layer.dilation, dtype=np.int)
        self.conv_kernel_size = self.P_union.broadcast_data(self.conv_kernel_size, root=0)
        self.conv_stride = self.P_union.broadcast_data(self.conv_stride, root=0)
        self.conv_padding = self.P_union.broadcast_data(self.conv_padding, root=0)
        self.conv_dilation = self.P_union.broadcast_data(self.conv_dilation, root=0)

        # We need to be able to remove some data from the input to the conv
        # layer but again need to defer.
        self.needed_slices = None

        # For the halo layer we also defer construction, so that we can have
        # the halo shape for the input.  The halo will allocate its own
        # buffers, but it needs this information at construction to be able
        # to do this in the pre-forward hook.
        self.halo_layer = None

        # Variables for tracking input changes and buffer construction
        self._distdl_is_setup = False
        self._input_tensor_structure = TensorStructure()

        # Some layers, those that require no information about the input
        # tensor to set up, can be built now.
        if P_w.active:
            self.w_broadcast = Broadcast(self.P_wr, self.P_w, preserve_batch=False)

        if self.receives_bias or self.stores_bias:
            self.b_broadcast = Broadcast(self.P_br, self.P_b, preserve_batch=False)

        self.x_broadcast = Broadcast(self.P_x, self.P_w, preserve_batch=True)
        self.y_sum_reduce = SumReduce(self.P_w, self.P_y, preserve_batch=True)
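
The call to `compute_subshape` above balances the channel counts over the channel partition so that each worker builds a local convolution with its share of the input and output channels. The helper below is only an illustration of the usual balanced split (remainder given to the leading workers); DistDL's exact rule may differ.

def balanced_split(n, parts):
    # Distribute n items over `parts` workers as evenly as possible.
    base, rem = divmod(n, parts)
    return [base + 1 if i < rem else base for i in range(parts)]

print(balanced_split(10, 3))  # [4, 3, 3]: local out_channels for P_co = 3
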
Example #21
File: pooling.py  Project: rhewett/distdl
    def __init__(self,
                 P_x,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 buffer_manager=None):

        super(DistributedPoolBase, self).__init__()

        # P_x is 1 x 1 x P_d-1 x ... x P_0
        self.P_x = P_x

        self.is_avg = self.TorchPoolType in [torch.nn.AvgPool1d, torch.nn.AvgPool2d, torch.nn.AvgPool3d]

        # Back-end specific buffer manager for economic buffer allocation
        if buffer_manager is None:
            buffer_manager = self._distdl_backend.BufferManager()
        elif type(buffer_manager) is not self._distdl_backend.BufferManager:
            raise ValueError("Buffer manager type does not match backend.")
        self.buffer_manager = buffer_manager

        if not self.P_x.active:
            return

        dims = len(self.P_x.shape)

        self.kernel_size = self._expand_parameter(kernel_size)
        self.stride = self._expand_parameter(stride)
        self.padding = self._expand_parameter(padding)
        self.dilation = self._expand_parameter(dilation)

        if self.is_avg and not all(x == 1 for x in self.dilation):
            raise ValueError('dilation is only supported for MaxPooling layers.')

        # PyTorch does not support dilation for AvgPooling layers
        if self.is_avg:
            self.pool_layer = self.TorchPoolType(kernel_size=self.kernel_size,
                                                 stride=self.stride,
                                                 padding=0)
        else:
            self.pool_layer = self.TorchPoolType(kernel_size=self.kernel_size,
                                                 stride=self.stride,
                                                 padding=0,
                                                 dilation=self.dilation)

        # We will be using global padding to compute local padding,
        # so expand it to a numpy array
        global_padding = np.pad(self.padding,
                                pad_width=(dims-len(self.padding), 0),
                                mode='constant',
                                constant_values=0)
        self.global_padding = global_padding

        pad_left_right = self.global_padding.reshape((dims, 1)) + np.zeros((dims, 2), dtype=np.int)
        self.local_padding = self._compute_local_padding(pad_left_right)

        # We need to be able to remove some data from the input to the pooling
        # layer.
        self.needed_slices = None

        # For the halo layer we also defer construction, so that we can have
        # the halo shape for the input.  The halo will allocate its own
        # buffers, but it needs this information at construction to be able
        # to do this in the pre-forward hook.
        self.halo_layer = None

        # Variables for tracking input changes and buffer construction
        self._distdl_is_setup = False
        self._input_tensor_structure = TensorStructure()
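
The padding bookkeeping above takes the user-supplied spatial padding, left-pads it with zeros up to the full partition dimension, and expands it into per-dimension (left, right) pairs before the local padding is derived from it. A plain-NumPy illustration with hypothetical values:

import numpy as np

dims = 4           # e.g. batch x channel x H x W
padding = [1, 2]   # user-supplied spatial padding only

global_padding = np.pad(padding,
                        pad_width=(dims - len(padding), 0),
                        mode='constant',
                        constant_values=0)

pad_left_right = global_padding.reshape((dims, 1)) + np.zeros((dims, 2), dtype=int)

print(global_padding)   # [0 0 1 2]
print(pad_left_right)   # rows: [0 0], [0 0], [1 1], [2 2]
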
Example #22
    def _distdl_module_setup(self, input):
        r"""Distributed (feature) pooling module setup function.

        This function is called every time something changes in the input
        tensor structure.  It should not be called manually.

        Parameters
        ----------
        input :
            Tuple of forward inputs.  See
            `torch.nn.Module.register_forward_pre_hook` for more details.

        """

        self._distdl_is_setup = True
        self._input_tensor_structure = TensorStructure(input[0])

        if not self.P_x.active:
            return

        # To compute the halo regions and interpolation, we need the global
        # tensor shape.  This is not available until the input is provided.
        global_input_tensor_structure = \
            self._distdl_backend.assemble_global_tensor_structure(input[0], self.P_x)

        if self.size is None:
            global_output_tensor_shape = torch.as_tensor(
                global_input_tensor_structure.shape).to(torch.float64)
            global_output_tensor_shape[2:] *= self.scale_factor

            # I prefer ceil(), torch uses floor(), so we go with floor for consistency
            global_output_tensor_shape = torch.Size(
                torch.floor(global_output_tensor_shape).to(torch.int64))
        else:
            if len(self.size) != len(global_input_tensor_structure.shape):
                raise ValueError(
                    "Provided size does not match input tensor dimension.")
            global_output_tensor_shape = torch.Size(torch.as_tensor(self.size))
        global_output_tensor_structure = TensorStructure()
        global_output_tensor_structure.shape = global_output_tensor_shape

        # Using that information, we can get the rest of the halo information
        exchange_info = self._compute_exchange_info(
            self.P_x, global_input_tensor_structure,
            global_output_tensor_structure, self.scale_factor, self.mode,
            self.align_corners)
        halo_shape = exchange_info[0]
        recv_buffer_shape = exchange_info[1]
        send_buffer_shape = exchange_info[2]
        needed_ranges = exchange_info[3]

        self.halo_shape = halo_shape

        # We can also set up part of the halo layer.
        self.halo_layer = HaloExchange(self.P_x,
                                       halo_shape,
                                       recv_buffer_shape,
                                       send_buffer_shape,
                                       buffer_manager=self.buffer_manager)

        # We have to select out the "unused" entries.  Sometimes there can
        # be "negative" halos.
        self.needed_slices = assemble_slices(needed_ranges[:, 0],
                                             needed_ranges[:, 1])

        # TODO #176: This block to compute the start and stop index of the
        # post-halo exchanged input can be cleaned up, as it is a duplicate of
        # calculation in the halo layer itself
        _slice = tuple([slice(i, i + 1)
                        for i in self.P_x.index] + [slice(None)])

        x_subtensor_shapes = compute_subtensor_shapes_balanced(
            global_input_tensor_structure, self.P_x.shape)
        x_subtensor_start_indices = compute_subtensor_start_indices(
            x_subtensor_shapes)
        x_subtensor_stop_indices = compute_subtensor_stop_indices(
            x_subtensor_shapes)

        x_start_index = torch.from_numpy(
            x_subtensor_start_indices[_slice].squeeze())
        x_stop_index = torch.from_numpy(
            x_subtensor_stop_indices[_slice].squeeze())

        y_subtensor_shapes = compute_subtensor_shapes_balanced(
            global_output_tensor_structure, self.P_x.shape)
        y_subtensor_start_indices = compute_subtensor_start_indices(
            y_subtensor_shapes)
        y_subtensor_stop_indices = compute_subtensor_stop_indices(
            y_subtensor_shapes)

        y_start_index = torch.from_numpy(
            y_subtensor_start_indices[_slice].squeeze())
        y_stop_index = torch.from_numpy(
            y_subtensor_stop_indices[_slice].squeeze())

        x_start_index = self._compute_needed_start(
            y_start_index, global_input_tensor_structure.shape,
            global_output_tensor_structure.shape, self.scale_factor, self.mode,
            self.align_corners)

        x_stop_index = self._compute_needed_stop(
            y_stop_index - 1, global_input_tensor_structure.shape,
            global_output_tensor_structure.shape, self.scale_factor, self.mode,
            self.align_corners)

        self.interp_layer = Interpolate(x_start_index,
                                        x_stop_index,
                                        global_input_tensor_structure.shape,
                                        y_start_index,
                                        y_stop_index,
                                        global_output_tensor_structure.shape,
                                        scale_factor=self.scale_factor,
                                        mode=self.mode,
                                        align_corners=self.align_corners)
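
When only `scale_factor` is given, the global output shape is obtained by scaling the spatial dimensions and flooring, as the setup above does to stay consistent with torch. A short torch-only illustration with a hypothetical input shape:

import torch

global_input_shape = torch.as_tensor([1, 3, 10, 15]).to(torch.float64)
scale_factor = 1.5

global_output_shape = global_input_shape.clone()
global_output_shape[2:] *= scale_factor
global_output_shape = torch.Size(
    torch.floor(global_output_shape).to(torch.int64))

print(global_output_shape)  # torch.Size([1, 3, 15, 22])
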
Example #23
    def __init__(self,
                 P_x,
                 P_y,
                 P_w,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 padding_mode='zeros',
                 dilation=1,
                 groups=1,
                 bias=True,
                 *args,
                 **kwargs):

        super(DistributedChannelConvBase, self).__init__()

        # P_x is 1    x P_ci x 1 x ... x 1
        self.P_x = P_x
        # P_y is 1    x P_co x 1 x ... x 1
        self.P_y = P_y
        # P_w is P_co x P_ci x 1 x ... x 1
        self.P_w = P_w

        # Even inactive workers need some partition union
        P_union = self._distdl_backend.Partition()

        if not (self.P_x.active or self.P_y.active or self.P_w.active):
            return

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = self._expand_parameter(kernel_size)
        self.stride = self._expand_parameter(stride)
        self.padding = self._expand_parameter(padding)
        self.padding_mode = padding_mode
        self.dilation = self._expand_parameter(dilation)
        self.groups = groups
        self.use_bias = bias

        # This guarantees that P_union rank 0 has the kernel size, stride,
        # padding, and dilation factors
        P_union_temp = P_w.create_partition_union(P_x)
        P_union = P_union_temp.create_partition_union(P_y)

        # Ensure that all workers have the full size and structure of P_w
        P_w_shape = None
        if P_union.rank == 0:
            P_w_shape = np.array(P_w.shape, dtype=np.int)
        P_w_shape = P_union.broadcast_data(P_w_shape, root=0)

        # Release the temporary resources
        P_union_temp.deactivate()
        P_union.deactivate()

        P_co = P_w_shape[0]
        P_ci = P_w_shape[1]
        P_channels = [P_co, P_ci]

        # Ensure that P_x and P_w are correctly aligned.  We also produce a
        # new P_x that is shaped like 1 x P_ci x 1 x ... x 1, to assist with
        # broadcasts.
        P_x_new_shape = []
        if self.P_x.active:
            if (np.any(P_x.shape[2:] != P_w_shape[2:])):
                raise ValueError(
                    "Spatial components of P_x and P_w must match.")
            if (np.any(P_x.shape[2:] != np.ones(len(P_x.shape[2:])))):
                raise ValueError(
                    "Spatial components of P_x must be 1 x ... x 1.")
            if P_w_shape[1] != P_x.shape[1]:
                raise ValueError(
                    "Index 2 of P_w dimension must match input channel partition."
                )
            P_x_new_shape = list(P_x.shape)
            P_x_new_shape.insert(1, 1)
            # Currently a hack, removing the batch dimension because P_w does
            # not have one. This is OK because we assume there are no partitions
            # in the batch dimension.
            P_x_new_shape = np.asarray(P_x_new_shape[1:], dtype=int)

        # For the purposes of this layer, we re-cast P_x to have the extra
        # dimension.  This has no impact outside of the layer or on the results.
        self.P_x = self.P_x.create_cartesian_topology_partition(P_x_new_shape)

        # Ensure that P_y and P_w are correctly aligned.  We also produce a
        # new P_y that is shaped like P_co x 1 x 1 x ... x 1, to assist with
        # broadcasts.
        P_y_new_shape = []
        if self.P_y.active:
            if (np.any(P_y.shape[2:] != P_w_shape[2:])):
                raise ValueError(
                    "Spatial components of P_y and P_w must match.")
            if (np.any(P_y.shape[2:] != np.ones(len(P_y.shape[2:])))):
                raise ValueError(
                    "Spatial components of P_y must be 1 x ... x 1.")
            if P_w_shape[0] != P_y.shape[1]:
                raise ValueError(
                    "Index 1 of P_w dimension must match output channel partition."
                )
            P_y_new_shape = list(P_y.shape)
            P_y_new_shape.insert(2, 1)
            # Currently a hack, removing the batch dimension because P_w does
            # not have one. This is OK because we assume there are no partitions
            # in the batch dimension.
            P_y_new_shape = np.asarray(P_y_new_shape[1:], dtype=int)

        # For the purposes of this layer, we re-cast P_y to have the extra
        # dimension.  This has no impact outside of the layer or on the results.
        self.P_y = self.P_y.create_cartesian_topology_partition(P_y_new_shape)

        self.serial = self.P_w.size == 1

        if self.serial:
            self.conv_layer = self.TorchConvType(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=self.kernel_size,
                stride=self.stride,
                padding=self.padding,
                padding_mode=self.padding_mode,
                dilation=self.dilation,
                groups=self.groups,
                bias=self.use_bias)
            self.weight = self.conv_layer.weight
            self.bias = self.conv_layer.bias
            return

        # Flag if the global bias is set
        self.global_bias = bias

        # Flags if current worker stores (part of) the bias locally.
        self.stores_bias = False

        if self.P_w.active:

            # Let the P_co column store the bias if it is to be used
            self.stores_bias = self.P_w.index[1] == 0 and self.use_bias

            # Correct the input arguments based on local properties
            # This ensures that the in and out channels are correctly shared.
            local_co, local_ci = compute_subshape(P_channels, P_w.index[0:2],
                                                  [out_channels, in_channels])
            self.conv_layer = self.TorchConvType(
                in_channels=local_ci,
                out_channels=local_co,
                kernel_size=self.kernel_size,
                stride=self.stride,
                padding=self.padding,
                padding_mode=self.padding_mode,
                dilation=self.dilation,
                groups=groups,
                bias=self.stores_bias)

        # Workers in P_w alias the conv layer to get their weight and perhaps
        # biases.  Every other worker doesn't have a weight or bias.
        if self.P_w.active:
            self.weight = self.conv_layer.weight
            if self.stores_bias:
                self.bias = self.conv_layer.bias
            else:
                if self.use_bias:
                    self.register_buffer('bias', zero_volume_tensor())
                else:
                    self.register_buffer('bias', None)
        else:
            self.register_buffer('weight', zero_volume_tensor())
            if self.use_bias:
                self.register_buffer('bias', zero_volume_tensor())
            else:
                self.register_buffer('bias', None)

        # Variables for tracking input changes and buffer construction
        self._distdl_is_setup = False
        self._input_tensor_structure = TensorStructure()

        self.x_broadcast = Broadcast(self.P_x, self.P_w, preserve_batch=True)
        self.y_sum_reduce = SumReduce(self.P_w, self.P_y, preserve_batch=True)
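
The constructor above only builds the communication primitives; for orientation, here is a minimal, hypothetical sketch of how a forward pass could compose them: broadcast the local input over the weight partition, run the worker-local convolution, then sum-reduce the partial results onto P_y. This is a sketch under those assumptions, not DistDL's actual forward method.

    def forward(self, input):
        # Hypothetical sketch only: broadcast the local input block over the
        # weight partition, apply the worker-local convolution, then
        # sum-reduce the partial results onto the output partition P_y.
        if self.serial:
            return self.conv_layer(input)

        x = self.x_broadcast(input)

        if self.P_w.active:
            x = self.conv_layer(x)

        return self.y_sum_reduce(x)
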
Example #24
    def __init__(self,
                 P_x,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 padding_mode='zeros',
                 dilation=1,
                 groups=1,
                 bias=True,
                 buffer_manager=None):

        super(DistributedFeatureConvBase, self).__init__()

        # P_x is 1 x 1 x P_d-1 x ... x P_0
        self.P_x = P_x

        # Back-end specific buffer manager for economic buffer allocation
        if buffer_manager is None:
            buffer_manager = self._distdl_backend.BufferManager()
        elif type(buffer_manager) is not self._distdl_backend.BufferManager:
            raise ValueError("Buffer manager type does not match backend.")
        self.buffer_manager = buffer_manager

        if not self.P_x.active:
            return

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = self._expand_parameter(kernel_size)
        self.stride = self._expand_parameter(stride)
        self.padding = self._expand_parameter(padding)
        self.padding_mode = padding_mode
        self.dilation = self._expand_parameter(dilation)
        self.groups = groups
        self.use_bias = bias

        self.serial = self.P_x.size == 1

        if self.serial:
            self.conv_layer = self.TorchConvType(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=self.kernel_size,
                stride=self.stride,
                padding=self.padding,
                padding_mode=self.padding_mode,
                dilation=self.dilation,
                groups=self.groups,
                bias=self.use_bias)
            self.weight = self.conv_layer.weight
            self.bias = self.conv_layer.bias
        else:
            self.conv_layer = self.TorchConvType(in_channels=in_channels,
                                                 out_channels=out_channels,
                                                 kernel_size=self.kernel_size,
                                                 stride=self.stride,
                                                 padding=0,
                                                 padding_mode='zeros',
                                                 dilation=self.dilation,
                                                 groups=groups,
                                                 bias=bias)

        if self.serial:
            return

        dims = len(self.P_x.shape)

        # We will use the global padding to compute the local padding, so
        # expand it to a numpy array covering all tensor dimensions (zeros
        # are prepended for the batch and channel dimensions).
        global_padding = np.pad(self.padding,
                                pad_width=(dims - len(self.padding), 0),
                                mode='constant',
                                constant_values=0)
        self.global_padding = global_padding

        pad_left_right = self.global_padding.reshape((dims, 1)) + np.zeros(
            (dims, 2), dtype=int)
        self.local_padding = self._compute_local_padding(pad_left_right)

        # Weights and biases partition
        P_wb = self.P_x.create_partition_inclusive([0])
        self.P_wb_cart = P_wb.create_cartesian_topology_partition([1])

        # Release temporary resources
        P_wb.deactivate()

        # We want only the root rank of the broadcast to have a weight and a
        # bias parameter. Every other rank gets a zero-volume tensor.
        if self.P_wb_cart.active:
            self.weight = torch.nn.Parameter(self.conv_layer.weight.detach())

            if self.conv_layer.bias is not None:
                self.bias = torch.nn.Parameter(self.conv_layer.bias.detach())
            else:
                self.register_buffer('bias', None)
        else:
            self.register_buffer('weight', zero_volume_tensor())

            if self.conv_layer.bias is not None:
                self.register_buffer('bias', zero_volume_tensor())
            else:
                self.register_buffer('bias', None)

        self.weight.requires_grad = self.conv_layer.weight.requires_grad

        if self.conv_layer.bias is not None:
            self.bias.requires_grad = self.conv_layer.bias.requires_grad

        # https://discuss.pytorch.org/t/assign-parameters-to-nn-module-and-have-grad-fn-track-it/62677/2
        new_weight = self.conv_layer.weight.detach() * 0
        new_weight.requires_grad = self.conv_layer.weight.requires_grad
        del self.conv_layer.weight
        self.conv_layer.weight = new_weight

        if self.conv_layer.bias is not None:
            new_bias = self.conv_layer.bias.detach() * 0
            new_bias.requires_grad = self.conv_layer.bias.requires_grad
            del self.conv_layer.bias
            self.conv_layer.bias = new_bias

        self.w_broadcast = Broadcast(self.P_wb_cart,
                                     self.P_x,
                                     preserve_batch=False)

        if self.conv_layer.bias is not None:
            self.b_broadcast = Broadcast(self.P_wb_cart,
                                         self.P_x,
                                         preserve_batch=False)

        # We need to be able to remove some data from the input to the conv
        # layer.
        self.needed_slices = None

        # Construction of the halo layer is also deferred so that the halo
        # shape can be computed from the input.  The halo exchange allocates
        # its own buffers, but it needs this information at construction,
        # which happens in the pre-forward hook.
        self.halo_layer = None

        # Variables for tracking input changes and buffer construction
        self._distdl_is_setup = False
        self._input_tensor_structure = TensorStructure()
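
As a rough companion to the parameter handling above, the following hedged sketch shows one way the root-only parameters could be consumed at forward time: broadcast them from P_wb_cart to every worker in P_x and add them to the zeroed conv_layer tensors, per the autograd trick referenced in the link above. It is illustrative only and is not DistDL's exact implementation.

    def forward(self, input):
        # Hypothetical sketch only: broadcast the root's parameters to every
        # worker in P_x and add them to the zeroed local copies, so that the
        # autograd graph stays connected to self.weight / self.bias (see the
        # linked discussion above).
        if self.serial:
            return self.conv_layer(input)

        w = self.w_broadcast(self.weight)
        self.conv_layer.weight = w + self.conv_layer.weight

        if self.conv_layer.bias is not None:
            b = self.b_broadcast(self.bias)
            self.conv_layer.bias = b + self.conv_layer.bias

        # The input-side steps (explicit padding, halo exchange, slicing) are
        # sketched after the setup function in the next example.
        return self.conv_layer(input)
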
Example #25
    def _distdl_module_setup(self, input):
        r"""Distributed (feature) convolution module setup function.

        This function is called every time something changes in the input
        tensor structure.  It should not be called manually.

        Parameters
        ----------
        input :
            Tuple of forward inputs.  See
            `torch.nn.Module.register_forward_pre_hook` for more details.

        """

        self._distdl_is_setup = True
        self._input_tensor_structure = TensorStructure(input[0])

        if not self.P_x.active:
            return

        if self.serial:
            return

        # Compute global and local shapes with padding
        x_global_structure = \
            self._distdl_backend.assemble_global_tensor_structure(input[0], self.P_x)
        x_local_structure = TensorStructure(input[0])
        x_global_shape = x_global_structure.shape
        x_local_shape = x_local_structure.shape
        x_global_shape_after_pad = x_global_shape + 2 * self.global_padding
        x_local_shape_after_pad = x_local_shape + np.sum(
            self.local_padding, axis=1, keepdims=False)
        x_local_structure_after_pad = TensorStructure(input[0])
        x_local_structure_after_pad.shape = x_local_shape_after_pad

        # We need to compute the halos with respect to the explicit padding.
        # So, we assume the padding is already added, then compute the halo regions.
        compute_subtensor_shapes_unbalanced = \
            self._distdl_backend.tensor_decomposition.compute_subtensor_shapes_unbalanced
        subtensor_shapes = \
            compute_subtensor_shapes_unbalanced(x_local_structure_after_pad, self.P_x)

        # Using that information, we can get the rest of the halo information.
        exchange_info = self._compute_exchange_info(
            x_global_shape_after_pad,
            self.kernel_size,
            self.stride,
            self._expand_parameter(0),
            self.dilation,
            self.P_x.active,
            self.P_x.shape,
            self.P_x.index,
            subtensor_shapes=subtensor_shapes)
        halo_shape = exchange_info[0]
        recv_buffer_shape = exchange_info[1]
        send_buffer_shape = exchange_info[2]
        needed_ranges = exchange_info[3]

        self.halo_shape = halo_shape

        # We can also set up part of the halo layer.
        self.halo_layer = HaloExchange(self.P_x,
                                       halo_shape,
                                       recv_buffer_shape,
                                       send_buffer_shape,
                                       buffer_manager=self.buffer_manager)

        # We have to select out the "unused" entries.  Sometimes there can
        # be "negative" halos.
        self.needed_slices = assemble_slices(needed_ranges[:, 0],
                                             needed_ranges[:, 1])
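
To tie the setup above together, here is a hedged, illustrative sketch of the input path of a forward pass: explicit local padding, halo exchange, trimming with needed_slices, then the local convolution. The padding helper below is hypothetical and assumes self.local_padding holds per-dimension [left, right] widths.

    def _apply_local_padding(self, input):
        # Hypothetical helper: assumes self.local_padding is a (dims, 2)
        # array of [left, right] pad widths per tensor dimension, and that
        # explicit zero padding is appropriate here.
        import torch.nn.functional as F
        pad = [int(p) for lr in self.local_padding[::-1] for p in lr]
        return F.pad(input, pad, mode='constant', value=0)

    def forward(self, input):
        # Hypothetical sketch only: pad the local block explicitly, exchange
        # halo regions with neighbors, trim any "negative" halo entries, then
        # run the worker-local convolution.
        if self.serial:
            return self.conv_layer(input)

        x = self._apply_local_padding(input)
        x = self.halo_layer(x)
        x = x[self.needed_slices]
        return self.conv_layer(x)
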
Example #26
class DistributedTranspose(Module):
    r"""A distributed transpose layer.

    This class provides the user interface to the transpose distributed data
    movement primitive.  Implementation details are back-end specific.

    The Transpose algorithm performs a transpose, shuffle, or generalized
    all-to-all from a tensor partitioned by `P_x` to a new tensor partitioned
    by `P_y`.  The values of the tensor do not change; only the distribution
    of the tensor over the workers changes.

    If ``P_x`` and ``P_y`` are exactly equal, then no data movement occurs.

    If the input and output tensors have a batch dimension, that dimension
    should be preserved, even for zero-volume outputs.  If they do not have a
    batch dimension, it should not be.  The `preserve_batch` option controls
    this behavior.

    Parameters
    ----------
    P_x :
        Partition of input tensor.
    P_y :
        Partition of output tensor.
    preserve_batch : bool, optional
        Indicates if batch size should be preserved for zero-volume outputs.
    buffer_manager : optional
        External manager for communication buffers.

    """
    def __init__(self, P_x, P_y, preserve_batch=True, buffer_manager=None):
        super(DistributedTranspose, self).__init__()

        # Global structure of the input tensor, assembled when layer is called
        self.global_input_tensor_structure = TensorStructure()

        # Local input and output tensor structures, defined when layer is called
        self.input_tensor_structure = TensorStructure()
        self.output_tensor_structure = TensorStructure()

        # Partition of input tensor.
        self.P_x = P_x

        # Partition of output tensor.
        self.P_y = P_y

        # Indicates if batch size should be preserved for zero-volume outputs.
        self.preserve_batch = preserve_batch

        # List of meta data describing copies of subvolumes of input tensor
        # out of the current worker
        self.P_x_to_y_overlaps = []

        # List of meta data describing copies of subvolumes of output tensor
        # into the current worker
        self.P_y_to_x_overlaps = []

        # Back-end specific buffer manager for economic buffer allocation
        if buffer_manager is None:
            buffer_manager = self._distdl_backend.BufferManager()
        elif type(buffer_manager) is not self._distdl_backend.BufferManager:
            raise ValueError("Buffer manager type does not match backend.")
        self.buffer_manager = buffer_manager

        # List of buffers for copying data to other workers
        self.P_x_to_y_buffers = None

        # List of buffers for copying data from other workers
        self.P_y_to_x_buffers = None

        # Variables for tracking input changes and buffer construction
        self._distdl_is_setup = False
        self._input_tensor_structure = TensorStructure()

        # We need the union of the input and output partitions so that data
        # can be copied across them.
        P_union = self._distdl_backend.Partition()
        if P_x.active or P_y.active:
            P_union = P_x.create_partition_union(P_y)
        self.P_union = P_union

        # Set up these variables in case the current worker is inactive in
        # the union.
        self.P_x_shape = None
        self.P_y_shape = None
        self.union_indices = None

        if not P_union.active:
            return

        # All active workers need the shapes of both partitions so that buffer
        # sizes and subtensor overlaps can be computed.
        data = None
        if self.P_x.active:
            data = self.P_x.shape
        self.P_x_shape = self.P_union.broadcast_data(data, P_data=self.P_x)

        data = None
        if self.P_y.active:
            data = self.P_y.shape
        self.P_y_shape = self.P_union.broadcast_data(data, P_data=self.P_y)

        if len(self.P_x_shape) != len(self.P_y_shape):
            raise ValueError(
                "Input and output partitions must have the same dimension.")

        # Share the rank mappings with every worker in the union.  P_x_ranks
        # maps each union rank to its rank in P_x (or -1 if inactive), and
        # P_y_ranks does the same for P_y.
        data = np.array([P_x.rank if P_x.active else -1], dtype=int)
        self.P_x_ranks = P_union.allgather_data(data)

        data = np.array([P_y.rank if P_y.active else -1], dtype=int)
        self.P_y_ranks = P_union.allgather_data(data)

        # Get some types and functions from the back-end
        self.allocate_transpose_buffers = self._distdl_backend.transpose.allocate_transpose_buffers

    def _distdl_module_setup(self, input):
        r"""Transpose module setup function.

        Constructs the necessary buffers and meta information about outbound
        and inbound copies to each worker.

        This function is called every time something changes in the input
        tensor structure.  It should not be called manually.

        Parameters
        ----------
        input :
            Tuple of forward inputs.  See
            `torch.nn.Module.register_forward_pre_hook` for more details.

        """

        self._distdl_is_setup = True
        self._input_tensor_structure.fill_from_tensor(input[0])

        # If we are not an active worker, do nothing.
        if not self.P_union.active:
            return

        self.input_tensor_structure = TensorStructure(input[0])

        self.global_input_tensor_structure = \
            self._distdl_backend.assemble_global_tensor_structure(self.input_tensor_structure,
                                                                  self.P_x,
                                                                  self.P_union)
        x_global_shape = self.global_input_tensor_structure.shape

        if self.P_y.active:
            self.output_tensor_structure.shape = compute_subshape(
                self.P_y.shape, self.P_y.index, x_global_shape)

        tensor_dim = len(x_global_shape)

        if len(self.P_x_shape) != tensor_dim:
            raise ValueError(
                f"Input partition mush have same dimension "
                f"({len(self.P_x_shape)}) as input tensor rank ({tensor_dim})."
            )

        if len(self.P_y_shape) != tensor_dim:
            raise ValueError(
                f"Output partition mush have same dimension "
                f"({len(self.P_y_shape)}) as input tensor rank ({tensor_dim})."
            )

        if 1 in x_global_shape[x_global_shape != self.P_y_shape]:
            raise ValueError(
                f"Input tensor must not be size 1 "
                f"({x_global_shape}) in a dimension where "
                f"output partition is other than 1 ({self.P_y_shape}).")

        # Get the collective input lengths and origins.  This may or may not
        # be load balanced, so we always assume it is not and build the
        # subshape tensor manually.  This output is needed everywhere, so it
        # goes to P_union.
        compute_subtensor_shapes_unbalanced = \
            self._distdl_backend.tensor_decomposition.compute_subtensor_shapes_unbalanced
        x_subtensor_shapes = compute_subtensor_shapes_unbalanced(
            self.input_tensor_structure, self.P_x, self.P_union)

        # Get the collective output lengths and origins. This will always be
        # load balanced, so we can infer the subshape tensor from the global
        # tensor shape and the shape of P_y.  At this point, every worker in
        # P_union has both of these pieces of information, so we can build it
        # with no communication.

        y_subtensor_shapes = compute_subtensor_shapes_balanced(
            self.global_input_tensor_structure, self.P_y_shape)

        # Given all subtensor shapes, we can compute the start and stop indices
        # for each partition.

        x_subtensor_start_indices = compute_subtensor_start_indices(
            x_subtensor_shapes)
        x_subtensor_stop_indices = compute_subtensor_stop_indices(
            x_subtensor_shapes)

        y_subtensor_start_indices = compute_subtensor_start_indices(
            y_subtensor_shapes)
        y_subtensor_stop_indices = compute_subtensor_stop_indices(
            y_subtensor_shapes)

        # We only need to move data to the output partition if we actually
        # have input data.  It is possible to have both input and output data,
        # either input or output data, or neither.  Hence the active guard.
        if self.P_x.active:
            x_slice = tuple([slice(i, i + 1)
                             for i in self.P_x.index] + [slice(None)])
            x_start_index = x_subtensor_start_indices[x_slice].squeeze()
            x_stop_index = x_subtensor_stop_indices[x_slice].squeeze()

            # Compute our overlaps for each output subpartition.
            for rank, P_y_index in enumerate(range_index(self.P_y_shape)):

                y_slice = tuple([slice(i, i + 1)
                                 for i in P_y_index] + [slice(None)])
                y_start_index = y_subtensor_start_indices[y_slice].squeeze()
                y_stop_index = y_subtensor_stop_indices[y_slice].squeeze()

                sl = compute_subtensor_intersection_slice(
                    x_start_index, x_stop_index, y_start_index, y_stop_index)

                if sl is not None:
                    sh = compute_nd_slice_shape(sl)
                    # If it is a self-copy, mark it so we don't have to create
                    # a potentially large buffer
                    if self.P_y.active and np.all(P_y_index == self.P_y.index):
                        partner = "self"
                    # Otherwise, reverse the mapping to get the output
                    # partner's rank in the common partition.
                    else:
                        partner = np.where(self.P_y_ranks == rank)[0][0]

                    self.P_x_to_y_overlaps.append((sl, sh, partner))

                else:
                    self.P_x_to_y_overlaps.append((None, None, None))

        # We only need to obtain data from the input partition if we actually
        # have output data.
        if self.P_y.active:
            y_slice = tuple([slice(i, i + 1)
                             for i in self.P_y.index] + [slice(None)])
            y_start_index = y_subtensor_start_indices[y_slice].squeeze()
            y_stop_index = y_subtensor_stop_indices[y_slice].squeeze()

            # Compute our overlaps for each input subpartition.
            for rank, P_x_index in enumerate(range_index(self.P_x_shape)):

                x_slice = tuple([slice(i, i + 1)
                                 for i in P_x_index] + [slice(None)])
                x_start_index = x_subtensor_start_indices[x_slice].squeeze()
                x_stop_index = x_subtensor_stop_indices[x_slice].squeeze()

                sl = compute_subtensor_intersection_slice(
                    y_start_index, y_stop_index, x_start_index, x_stop_index)

                if sl is not None:
                    sh = compute_nd_slice_shape(sl)
                    # If it is a self-copy, mark it so we don't have to create
                    # a potentially large buffer
                    if self.P_x.active and np.all(P_x_index == self.P_x.index):
                        partner = "self"
                    # Otherwise, reverse the mapping to get the input
                    # partner's rank in the common partition.
                    else:
                        partner = np.where(self.P_x_ranks == rank)[0][0]

                    self.P_y_to_x_overlaps.append((sl, sh, partner))

                else:
                    self.P_y_to_x_overlaps.append((None, None, None))

        buffs = self.allocate_transpose_buffers(
            self.buffer_manager, self.P_x_to_y_overlaps,
            self.P_y_to_x_overlaps, self.global_input_tensor_structure.dtype)
        self.P_x_to_y_buffers = buffs[0]
        self.P_y_to_x_buffers = buffs[1]

    def _distdl_module_teardown(self, input):
        r"""Transpose module teardown function.

        Deallocates buffers safely.

        This function is called every time something changes in the input
        tensor structure.  It should not be called manually.

        Parameters
        ----------
        input :
            Tuple of forward inputs.  See
            `torch.nn.Module.register_forward_pre_hook` for more details.

        """

        # Reset all of the buffers and communication objects
        self.P_x_to_y_overlaps = []
        self.P_y_to_x_overlaps = []

        self.P_x_to_y_buffers = None
        self.P_y_to_x_buffers = None

        # Reset any info about the input
        self._distdl_is_setup = False
        self._input_tensor_structure = TensorStructure()

    def _distdl_input_changed(self, input):
        r"""Determine if the structure of inputs has changed.

        Parameters
        ----------
        input :
            Tuple of forward inputs.  See
            `torch.nn.Module.register_forward_pre_hook` for more details.

        """

        new_tensor_structure = TensorStructure(input[0])

        return self._input_tensor_structure != new_tensor_structure

    def forward(self, input):
        """Forward function interface.

        Parameters
        ----------
        input :
            Input tensor to be transposed (repartitioned).

        """

        Function = self._distdl_backend.functional.transpose.DistributedTransposeFunction

        # If this worker is not active for either the input or the output,
        # then no communication is necessary: the input should be a
        # zero-volume tensor, and the output is simply a clone of it.
        if not (self.P_x.active or self.P_y.active):
            return input.clone()

        return Function.apply(
            input, self.P_union, self.global_input_tensor_structure,
            self.input_tensor_structure, self.output_tensor_structure,
            self.P_x, self.P_x_to_y_overlaps, self.P_x_to_y_buffers, self.P_y,
            self.P_y_to_x_overlaps, self.P_y_to_x_buffers, self.preserve_batch)
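
Finally, a minimal usage sketch for this layer with four MPI workers: a 16 x 16 tensor is re-partitioned from a row decomposition to a column decomposition. The import paths, the MPIPartition wrapper, and the location of compute_subshape are assumptions about DistDL's usual layout; the partition-construction calls mirror those used in the examples above.

# Hypothetical usage sketch: re-partition a 16 x 16 tensor from a row
# decomposition to a column decomposition across four workers.
# Run with something like: mpirun -np 4 python transpose_example.py
import torch
from mpi4py import MPI

from distdl.backends.mpi.partition import MPIPartition
from distdl.nn.transpose import DistributedTranspose
from distdl.utilities.slicing import compute_subshape

P_world = MPIPartition(MPI.COMM_WORLD)
P_base = P_world.create_partition_inclusive(list(range(4)))

# Input split into 4 row blocks; output split into 4 column blocks.
P_x = P_base.create_cartesian_topology_partition([4, 1])
P_y = P_base.create_cartesian_topology_partition([1, 4])

layer = DistributedTranspose(P_x, P_y, preserve_batch=False)

# Each worker supplies only its local block of the global tensor.
x_global_shape = [16, 16]
x_local_shape = compute_subshape(P_x.shape, P_x.index, x_global_shape)
x = torch.zeros(*[int(s) for s in x_local_shape])

y = layer(x)  # this worker's block of the column-decomposed tensor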