Example #1
    def _distdl_module_setup(self, input):

        if not (self.P_x.active or self.P_y.active or self.P_w.active):
            return

        if self.serial:
            return

        x_global_shape = self._distdl_backend.compute_global_tensor_shape(
            input[0], self.P_x, self.P_union)
        if self.P_x.active:
            exchange_info = self._compute_exchange_info(
                x_global_shape, self.conv_kernel_size, self.conv_stride,
                self.conv_padding, self.conv_dilation, self.P_x.active,
                self.P_x.shape, self.P_x.index)
            halo_shape = exchange_info[0]
            recv_buffer_shape = exchange_info[1]
            send_buffer_shape = exchange_info[2]
            needed_ranges = exchange_info[3]

            # Now we have enough information to instantiate the padding shim
            self.pad_layer = PadNd(halo_shape, value=0)

            # We can also set up part of the halo layer.
            self.halo_layer = HaloExchange(self.P_x, halo_shape,
                                           recv_buffer_shape,
                                           send_buffer_shape)

            # We have to select out the "unused" entries.
            self.needed_slices = assemble_slices(needed_ranges[:, 0],
                                                 needed_ranges[:, 1])

        # The output has to do some unpadding
        if self.P_y.active:

            # This is safe because there are never halos on the channel or
            # batch dimensions.  Therefore, because we assume that the spatial
            # partitions of P_x and P_y are the same, the halo shape computed
            # here will also be the same.
            exchange_info = self._compute_exchange_info(
                x_global_shape, self.conv_kernel_size, self.conv_stride,
                self.conv_padding, self.conv_dilation, self.P_y.active,
                self.P_y.shape, self.P_y.index)
            y_halo_shape = exchange_info[0]

            # The unpad shape is the convolution padding in the dimensions
            # where we have a halo, and 0 otherwise.
            conv_padding = np.concatenate(([0, 0], self.conv_padding))
            unpad_shape = []
            for pad, halo in zip(conv_padding, y_halo_shape):
                unpad_shape.append(np.where(halo > 0, pad, 0))
            unpad_shape = np.asarray(unpad_shape)

            self.unpad_layer = UnpadNd(unpad_shape, value=0)

        self._distdl_is_setup = True
        self._input_shape = input[0].shape
        self._input_requires_grad = input[0].requires_grad
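The unpad computation above keeps the convolution padding only along dimensions that received a halo.  A minimal standalone sketch of that rule, with made-up padding and halo values (the per-side [left, right] layout of y_halo_shape is an assumption here):

import numpy as np

# Hypothetical values: conv padding per spatial dimension, prepended with
# [0, 0] for the batch and channel dimensions.
conv_padding = np.concatenate(([0, 0], [1, 1]))           # -> [0, 0, 1, 1]
# Hypothetical per-dimension halo sizes, zero on batch and channel.
y_halo_shape = np.array([[0, 0], [0, 0], [2, 0], [0, 3]])

unpad_shape = np.asarray([np.where(halo > 0, pad, 0)
                          for pad, halo in zip(conv_padding, y_halo_shape)])
# unpad_shape == [[0, 0], [0, 0], [1, 0], [0, 1]]: the padding survives only
# on the sides where a halo was exchanged.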
Example #2
def compute_subtensor_intersection_slice(x_start_index, x_stop_index,
                                         y_start_index, y_stop_index):
    r"""Given index bounds (start and stop indices), compute the overlap of
        two Cartesian indexed regions.

        The start and stop index sets of two D-dimensional regions, x and y,
        are sufficient to determine the indices of their overlaps, if any.

        Parameters
        ----------
        x_start_index : iterable
            D starting indices of the first tensor.
        x_stop_index : iterable
            D stopping indices of the first tensor.
        y_start_index : iterable
            D starting indices of the second tensor.
        y_stop_index : iterable
            D stopping indices of the second tensor.

        Returns
        -------
        Tuple of slice objects describing the overlap, relative to the first
        (x) tensor if there is non-zero overlap, `None` otherwise.

    """

    x_start_index = np.atleast_1d(x_start_index)
    x_stop_index = np.atleast_1d(x_stop_index)
    y_start_index = np.atleast_1d(y_start_index)
    y_stop_index = np.atleast_1d(y_stop_index)

    if (len(x_start_index.shape) != 1) or \
       (len(x_stop_index.shape) != 1) or \
       (len(y_start_index.shape) != 1) or \
       (len(y_stop_index.shape) != 1):
        raise ValueError("Index lists must be convertible to 1-dimensional arrays.")

    # Compute the intersection between the x and y subtensors and its volume
    i_start_index, i_stop_index, i_shape = compute_intersection(x_start_index,
                                                                x_stop_index,
                                                                y_start_index,
                                                                y_stop_index)
    i_volume = np.prod(i_shape)

    # If the volume of the intersection is 0, we have no slice, otherwise we
    # need to determine the slices for the intersection relative to
    # coordinates of x.

    if i_volume == 0:
        return None
    else:
        i_start_index_rel_x = i_start_index - x_start_index
        i_stop_index_rel_x = i_start_index_rel_x + i_shape
        i_slice_rel_x = assemble_slices(i_start_index_rel_x, i_stop_index_rel_x)
        return i_slice_rel_x
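A minimal usage sketch of this function with illustrative index bounds (the returned format follows assemble_slices as used above, one slice per dimension):

import numpy as np

# Subtensor x covers global indices [0, 4) x [0, 6); subtensor y covers
# [2, 8) x [3, 6).  Their overlap, expressed relative to x, starts at (2, 3).
sl = compute_subtensor_intersection_slice(np.array([0, 0]), np.array([4, 6]),
                                          np.array([2, 3]), np.array([8, 6]))
# sl == (slice(2, 4), slice(3, 6))

# Disjoint regions have zero intersection volume and yield None.
assert compute_subtensor_intersection_slice([0, 0], [2, 2],
                                             [2, 2], [4, 4]) is None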
Example #3
File: conv.py  Project: denfromufa/distdl
    def _distdl_module_setup(self, input):

        self._distdl_is_setup = True
        self._input_shape = input[0].shape
        self._input_requires_grad = input[0].requires_grad

        if not self.P_x.active:
            return

        if self.serial:
            return

        x_global_shape = self._distdl_backend.compute_global_tensor_shape(
            input[0], self.P_x)

        exchange_info = self._compute_exchange_info(
            x_global_shape, self.conv_layer.kernel_size,
            self.conv_layer.stride, self.conv_layer.padding,
            self.conv_layer.dilation, self.P_x.active, self.P_x.shape,
            self.P_x.index)
        halo_shape = exchange_info[0]
        recv_buffer_shape = exchange_info[1]
        send_buffer_shape = exchange_info[2]
        needed_ranges = exchange_info[3]

        # Now we have enough information to instantiate the padding shim
        self.pad_layer = PadNd(halo_shape, value=0)

        # We can also set up part of the halo layer.
        self.halo_layer = HaloExchange(self.P_x, halo_shape, recv_buffer_shape,
                                       send_buffer_shape)

        # We have to select out the "unused" entries.
        self.needed_slices = assemble_slices(needed_ranges[:, 0],
                                             needed_ranges[:, 1])

        # The unpad shape is the conv layer's padding in the dimensions where
        # we have a halo, and 0 otherwise.  There is no halo in the batch and
        # channel dimensions.
        conv_padding = np.concatenate(([0, 0], self.conv_layer.padding))
        unpad_shape = []
        for pad, halo in zip(conv_padding, halo_shape):
            unpad_shape.append(np.where(halo > 0, pad, 0))
        unpad_shape = np.asarray(unpad_shape)

        self.unpad_layer = UnpadNd(unpad_shape, value=0)
Example #4
    def _distdl_module_setup(self, input):

        self._distdl_is_setup = True
        self._input_shape = input[0].shape
        self._input_requires_grad = input[0].requires_grad

        if not self.P_x.active:
            return

        x_global_shape = self._distdl_backend.compute_global_tensor_shape(input[0],
                                                                          self.P_x)
        self.x_global_shape = x_global_shape

        exchange_info = self._compute_exchange_info(x_global_shape,
                                                    self.pool_layer.kernel_size,
                                                    self.pool_layer.stride,
                                                    self.pool_layer.padding,
                                                    [1],  # torch pooling layers have no dilation
                                                    self.P_x.active,
                                                    self.P_x.shape,
                                                    self.P_x.index)
        halo_shape = exchange_info[0]
        recv_buffer_shape = exchange_info[1]
        send_buffer_shape = exchange_info[2]
        needed_ranges = exchange_info[3]

        # Now we have enough information to instantiate the padding shim
        self.pad_layer = PadNd(halo_shape, value=0)

        # We can also set up part of the halo layer.
        self.halo_layer = HaloExchange(self.P_x,
                                       halo_shape,
                                       recv_buffer_shape,
                                       send_buffer_shape)

        # We have to select out the "unused" entries.
        self.needed_slices = assemble_slices(needed_ranges[:, 0],
                                             needed_ranges[:, 1])
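Each of these setup routines finishes by turning needed_ranges, an array whose first column holds per-dimension start indices and whose second column holds stop indices, into a tuple of slices.  A minimal sketch of that pattern, assuming assemble_slices simply pairs starts and stops into slice objects:

import numpy as np
import torch

def assemble_slices(start, stop):
    # Assumed behavior: one slice per dimension from paired start/stop indices.
    return tuple(slice(int(b), int(e)) for b, e in zip(start, stop))

# Hypothetical needed_ranges for an (N, C, H, W) subtensor after a halo
# exchange: column 0 holds start indices, column 1 holds stop indices.
needed_ranges = np.array([[0, 2], [0, 3], [1, 9], [0, 10]])
needed_slices = assemble_slices(needed_ranges[:, 0], needed_ranges[:, 1])

x = torch.zeros(2, 3, 10, 10)
x_needed = x[needed_slices]   # drops the "unused" first row along H
# x_needed.shape == torch.Size([2, 3, 8, 10])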
Example #5
    def _distdl_module_setup(self, input):
        r"""Distributed (feature) pooling module setup function.

        This function is called every time something changes in the input
        tensor structure.  It should not be called manually.

        Parameters
        ----------
        input :
            Tuple of forward inputs.  See
            `torch.nn.Module.register_forward_pre_hook` for more details.

        """

        self._distdl_is_setup = True
        self._input_tensor_structure = TensorStructure(input[0])

        if not self.P_x.active:
            return

        # To compute the halo regions and interpolation, we need the global
        # tensor shape.  This is not available until the input is provided.
        global_input_tensor_structure = \
            self._distdl_backend.assemble_global_tensor_structure(input[0], self.P_x)

        if self.size is None:
            global_output_tensor_shape = torch.as_tensor(
                global_input_tensor_structure.shape).to(torch.float64)
            global_output_tensor_shape[2:] *= self.scale_factor

            # I prefer ceil(), torch uses floor(), so we go with floor for consistency
            global_output_tensor_shape = torch.Size(
                torch.floor(global_output_tensor_shape).to(torch.int64))
        else:
            if len(self.size) != len(global_input_tensor_structure.shape):
                raise ValueError(
                    "Provided size does not match input tensor dimension.")
            global_output_tensor_shape = torch.Size(torch.as_tensor(self.size))
        global_output_tensor_structure = TensorStructure()
        global_output_tensor_structure.shape = global_output_tensor_shape

        # Using that information, we can get the rest of the halo information
        exchange_info = self._compute_exchange_info(
            self.P_x, global_input_tensor_structure,
            global_output_tensor_structure, self.scale_factor, self.mode,
            self.align_corners)
        halo_shape = exchange_info[0]
        recv_buffer_shape = exchange_info[1]
        send_buffer_shape = exchange_info[2]
        needed_ranges = exchange_info[3]

        self.halo_shape = halo_shape

        # We can also set up part of the halo layer.
        self.halo_layer = HaloExchange(self.P_x,
                                       halo_shape,
                                       recv_buffer_shape,
                                       send_buffer_shape,
                                       buffer_manager=self.buffer_manager)

        # We have to select out the "unused" entries.  Sometimes there can
        # be "negative" halos.
        self.needed_slices = assemble_slices(needed_ranges[:, 0],
                                             needed_ranges[:, 1])

        # TODO #176: This block to compute the start and stop index of the
        # post-halo exchanged input can be cleaned up, as it is a duplicate of
        # calculation in the halo layer itself
        _slice = tuple([slice(i, i + 1)
                        for i in self.P_x.index] + [slice(None)])

        x_subtensor_shapes = compute_subtensor_shapes_balanced(
            global_input_tensor_structure, self.P_x.shape)
        x_subtensor_start_indices = compute_subtensor_start_indices(
            x_subtensor_shapes)
        x_subtensor_stop_indices = compute_subtensor_stop_indices(
            x_subtensor_shapes)

        x_start_index = torch.from_numpy(
            x_subtensor_start_indices[_slice].squeeze())
        x_stop_index = torch.from_numpy(
            x_subtensor_stop_indices[_slice].squeeze())

        y_subtensor_shapes = compute_subtensor_shapes_balanced(
            global_output_tensor_structure, self.P_x.shape)
        y_subtensor_start_indices = compute_subtensor_start_indices(
            y_subtensor_shapes)
        y_subtensor_stop_indices = compute_subtensor_stop_indices(
            y_subtensor_shapes)

        y_start_index = torch.from_numpy(
            y_subtensor_start_indices[_slice].squeeze())
        y_stop_index = torch.from_numpy(
            y_subtensor_stop_indices[_slice].squeeze())

        x_start_index = self._compute_needed_start(
            y_start_index, global_input_tensor_structure.shape,
            global_output_tensor_structure.shape, self.scale_factor, self.mode,
            self.align_corners)

        x_stop_index = self._compute_needed_stop(
            y_stop_index - 1, global_input_tensor_structure.shape,
            global_output_tensor_structure.shape, self.scale_factor, self.mode,
            self.align_corners)

        self.interp_layer = Interpolate(x_start_index,
                                        x_stop_index,
                                        global_input_tensor_structure.shape,
                                        y_start_index,
                                        y_stop_index,
                                        global_output_tensor_structure.shape,
                                        scale_factor=self.scale_factor,
                                        mode=self.mode,
                                        align_corners=self.align_corners)
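The _slice built from self.P_x.index above selects this worker's entry from arrays that are indexed by partition coordinate in the leading dimensions and by tensor dimension in the last one.  A small numpy illustration of that indexing trick; the array layout produced by compute_subtensor_start_indices is an assumption:

import numpy as np

# Hypothetical 1 x 1 x 2 x 2 partition; this worker sits at index (0, 0, 1, 0).
P_x_index = (0, 0, 1, 0)

# Assumed layout: one start-index vector per worker, stored in an array of
# shape P_x.shape + (tensor dimension,).
x_subtensor_start_indices = np.zeros((1, 1, 2, 2, 4), dtype=np.int64)
x_subtensor_start_indices[0, 0, 1, 0] = [0, 0, 5, 0]

_slice = tuple([slice(i, i + 1) for i in P_x_index] + [slice(None)])
x_start_index = x_subtensor_start_indices[_slice].squeeze()
# x_start_index == array([0, 0, 5, 0]): the global start index of this
# worker's subtensor.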
Example #6
    def _distdl_module_setup(self, input):
        r"""Distributed (feature) convolution module setup function.

        This function is called every time something changes in the input
        tensor structure.  It should not be called manually.

        Parameters
        ----------
        input :
            Tuple of forward inputs.  See
            `torch.nn.Module.register_forward_pre_hook` for more details.

        """

        self._distdl_is_setup = True
        self._input_tensor_structure = TensorStructure(input[0])

        if not self.P_x.active:
            return

        if self.serial:
            return

        # Compute global and local shapes with padding
        x_global_structure = \
            self._distdl_backend.assemble_global_tensor_structure(input[0], self.P_x)
        x_local_structure = TensorStructure(input[0])
        x_global_shape = x_global_structure.shape
        x_local_shape = x_local_structure.shape
        x_global_shape_after_pad = x_global_shape + 2 * self.global_padding
        x_local_shape_after_pad = x_local_shape + np.sum(
            self.local_padding, axis=1, keepdims=False)
        x_local_structure_after_pad = TensorStructure(input[0])
        x_local_structure_after_pad.shape = x_local_shape_after_pad

        # We need to compute the halos with respect to the explicit padding.
        # So, we assume the padding is already added, then compute the halo regions.
        compute_subtensor_shapes_unbalanced = \
            self._distdl_backend.tensor_decomposition.compute_subtensor_shapes_unbalanced
        subtensor_shapes = \
            compute_subtensor_shapes_unbalanced(x_local_structure_after_pad, self.P_x)

        # Using that information, we can get the rest of the halo information
        exchange_info = self._compute_exchange_info(
            x_global_shape_after_pad,
            self.kernel_size,
            self.stride,
            self._expand_parameter(0),
            self.dilation,
            self.P_x.active,
            self.P_x.shape,
            self.P_x.index,
            subtensor_shapes=subtensor_shapes)
        halo_shape = exchange_info[0]
        recv_buffer_shape = exchange_info[1]
        send_buffer_shape = exchange_info[2]
        needed_ranges = exchange_info[3]

        self.halo_shape = halo_shape

        # We can also set up part of the halo layer.
        self.halo_layer = HaloExchange(self.P_x,
                                       halo_shape,
                                       recv_buffer_shape,
                                       send_buffer_shape,
                                       buffer_manager=self.buffer_manager)

        # We have to select out the "unused" entries.  Sometimes there can
        # be "negative" halos.
        self.needed_slices = assemble_slices(needed_ranges[:, 0],
                                             needed_ranges[:, 1])
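The padded local shape computed above grows each dimension by the sum of its left and right explicit padding.  A minimal numeric sketch; the (dimension, 2) layout of local_padding is inferred from the axis=1 sum and is otherwise an assumption:

import numpy as np

# Hypothetical local (N, C, H, W) subtensor shape and per-side explicit padding.
x_local_shape = np.array([2, 3, 8, 8])
local_padding = np.array([[0, 0], [0, 0], [1, 0], [1, 1]])

x_local_shape_after_pad = x_local_shape + np.sum(local_padding, axis=1)
# -> array([ 2,  3,  9, 10]): each dimension grows by its left + right padding.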