Example #1
    def __init__(self, P_x, P_y, P_w, in_features, out_features, bias=True):

        super(DistributedLinear, self).__init__()

        # P_x ~ 1 X P_fi
        self.P_x = P_x
        # P_y ~ 1 X P_fo
        self.P_y = P_y
        # P_w ~ P_fo X P_fi
        self.P_w = P_w

        self.bias = bias

        self.x_broadcast = Broadcast(self.P_x, self.P_w, preserve_batch=True)

        if self.P_w.active:
            local_in_features = compute_subshape(P_w.shape[1], P_w.index[1],
                                                 in_features)
            local_out_features = compute_subshape(P_w.shape[0], P_w.index[0],
                                                  out_features)
            # On column 0, use the specified bias, otherwise no bias to
            # prevent double counting
            bias = self.bias if (self.P_w.index[-1] == 0) else False
            self.sublinear = torch.nn.Linear(local_in_features[0],
                                             local_out_features[0],
                                             bias=bias)

        self.y_sum_reduce = SumReduce(self.P_w,
                                      self.P_y,
                                      transpose_src=True,
                                      preserve_batch=True)
Example #2
def test_transpose_adjoint(barrier_fence_fixture,
                           comm_split_fixture,
                           P_x_ranks, P_x_shape,
                           P_y_ranks, P_y_shape,
                           x_global_shape):

    import numpy as np
    import torch

    from distdl.backends.mpi.partition import MPIPartition
    from distdl.nn.transpose import DistributedTranspose
    from distdl.utilities.slicing import compute_subshape
    from distdl.utilities.torch import zero_volume_tensor

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return
    P_world = MPIPartition(base_comm)

    # Create the partitions
    P_x_base = P_world.create_partition_inclusive(P_x_ranks)
    P_x = P_x_base.create_cartesian_topology_partition(P_x_shape)

    P_y_base = P_world.create_partition_inclusive(P_y_ranks)
    P_y = P_y_base.create_cartesian_topology_partition(P_y_shape)

    # The global tensor size is the same for x and y
    layer = DistributedTranspose(P_x, P_y, preserve_batch=False)

    # Forward Input
    x = zero_volume_tensor()
    if P_x.active:
        x_local_shape = compute_subshape(P_x.shape,
                                         P_x.index,
                                         x_global_shape)
        x = torch.Tensor(np.random.randn(*x_local_shape))
    x.requires_grad = True

    # Adjoint Input
    dy = zero_volume_tensor()
    if P_y.active:
        y_local_shape = compute_subshape(P_y.shape,
                                         P_y.index,
                                         x_global_shape)
        dy = torch.Tensor(np.random.randn(*y_local_shape))

    # y = F @ x
    y = layer(x)

    # dx = F* @ dy
    y.backward(dy)
    dx = x.grad

    x = x.detach()
    dx = dx.detach()
    dy = dy.detach()
    y = y.detach()

    check_adjoint_test_tight(P_world, x, dx, y, dy)
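check_adjoint_test_tight relies on the adjoint identity <F x, dy> = <x, F* dy>; the distributed version presumably sum-reduces the local inner products over P_world before comparing them. A purely serial sketch of the same check, using an ordinary linear map in place of the transpose layer, is:

import torch

def serial_adjoint_check(x, dx, y, dy, rtol=1e-4):
    # <F x, dy> should equal <x, F* dy> up to floating-point error.
    ip1 = torch.sum(y * dy)
    ip2 = torch.sum(x * dx)
    assert torch.isclose(ip1, ip2, rtol=rtol), (ip1.item(), ip2.item())

A = torch.randn(5, 3, dtype=torch.float64)
x = torch.randn(3, dtype=torch.float64, requires_grad=True)
y = A @ x                      # y = F @ x
dy = torch.randn(5, dtype=torch.float64)
y.backward(dy)                 # x.grad = A.T @ dy = F* @ dy
serial_adjoint_check(x.detach(), x.grad, y.detach(), dy)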
Example #3
    def _compute_halo_shape(self,
                            shape,
                            index,
                            x_global_shape,
                            kernel_size,
                            stride,
                            padding,
                            dilation,
                            require_nonnegative=True):

        x_global_shape = np.asarray(x_global_shape)

        x_local_shape = compute_subshape(shape, index, x_global_shape)
        x_local_start_index = compute_start_index(shape, index, x_global_shape)

        # formula from pytorch docs for maxpool
        y_global_shape = self._compute_out_shape(x_global_shape, kernel_size,
                                                 stride, padding, dilation)

        y_local_shape = compute_subshape(shape, index, y_global_shape)
        y_local_start_index = compute_start_index(shape, index, y_global_shape)

        y_local_left_global_index = y_local_start_index
        x_local_left_global_index_needed = self._compute_min_input_range(y_local_left_global_index,
                                                                         kernel_size,
                                                                         stride,
                                                                         padding,
                                                                         dilation)
        # Clamp to the boundary
        x_local_left_global_index_needed = np.maximum(np.zeros_like(x_global_shape),
                                                      x_local_left_global_index_needed)

        y_local_right_global_index = y_local_start_index + y_local_shape - 1
        x_local_right_global_index_needed = self._compute_max_input_range(y_local_right_global_index,
                                                                          kernel_size,
                                                                          stride,
                                                                          padding,
                                                                          dilation)
        # Clamp to the boundary
        x_local_right_global_index_needed = np.minimum(x_global_shape - 1,
                                                       x_local_right_global_index_needed)

        # Compute the actual ghost values
        x_local_left_halo_shape = x_local_start_index - x_local_left_global_index_needed
        x_local_stop_index = x_local_start_index + x_local_shape - 1
        x_local_right_halo_shape = x_local_right_global_index_needed - x_local_stop_index

        # Make sure the halos are always positive, so we get valid buffer shape
        if require_nonnegative:
            x_local_left_halo_shape = np.maximum(x_local_left_halo_shape, 0)
            x_local_right_halo_shape = np.maximum(x_local_right_halo_shape, 0)

        return np.hstack([x_local_left_halo_shape, x_local_right_halo_shape]).reshape(2, -1).T
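_compute_out_shape and _compute_min_input_range / _compute_max_input_range are not shown here. Assuming they follow the standard PyTorch pooling/convolution formulas referenced in the comment above, minimal stand-ins would be:

import numpy as np

def compute_out_shape(x_shape, kernel_size, stride, padding, dilation):
    # PyTorch's output-size formula for conv/pool layers (floor mode).
    x_shape = np.asarray(x_shape)
    return (x_shape + 2*padding - dilation*(kernel_size - 1) - 1) // stride + 1

def compute_min_input_range(y_index, kernel_size, stride, padding, dilation):
    # Leftmost input index read by output index y_index; may be negative,
    # i.e., fall inside the implicit zero padding.
    return y_index*stride - padding

def compute_max_input_range(y_index, kernel_size, stride, padding, dilation):
    # Rightmost input index read by output index y_index.
    return y_index*stride - padding + dilation*(kernel_size - 1)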
Example #4
def test_simple_conv2d_adjoint_weight(barrier_fence_fixture,
                                      comm_split_fixture,
                                      P_x_ranks, P_x_shape,
                                      x_global_shape):

    import numpy as np
    import torch

    from distdl.backends.mpi.partition import MPIPartition
    from distdl.nn.conv_feature import DistributedFeatureConv2d
    from distdl.utilities.slicing import compute_subshape
    from distdl.utilities.torch import zero_volume_tensor

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return
    P_world = MPIPartition(base_comm)

    # Create the partitions
    P_x_base = P_world.create_partition_inclusive(P_x_ranks)
    P_x = P_x_base.create_cartesian_topology_partition(P_x_shape)

    x_global_shape = np.asarray(x_global_shape)

    layer = DistributedFeatureConv2d(P_x,
                                     in_channels=x_global_shape[1],
                                     out_channels=10,
                                     kernel_size=[3, 3], bias=False)

    x = zero_volume_tensor(x_global_shape[0])
    if P_x.active:
        x_local_shape = compute_subshape(P_x.shape,
                                         P_x.index,
                                         x_global_shape)
        x = torch.randn(*x_local_shape)
    x.requires_grad = True

    y = layer(x)

    dy = zero_volume_tensor(x_global_shape[0])
    if P_x.active:
        dy = torch.randn(*y.shape)

    y.backward(dy)

    W = zero_volume_tensor()
    dW = zero_volume_tensor()
    if P_x.active:
        W = layer.weight.detach()
        dW = layer.weight.grad.detach()

    dy = dy.detach()
    y = y.detach()

    check_adjoint_test_tight(P_world, W, dW, y, dy)

    P_world.deactivate()
    P_x_base.deactivate()
    P_x.deactivate()
Example #5
def test_linear_adjoint_input(barrier_fence_fixture, comm_split_fixture,
                              P_x_ranks, P_x_shape, P_y_ranks, P_y_shape,
                              P_w_ranks, P_w_shape, x_global_shape,
                              y_global_shape):

    import numpy as np
    import torch

    from distdl.backends.mpi.partition import MPIPartition
    from distdl.nn.linear import DistributedLinear
    from distdl.utilities.slicing import compute_subshape
    from distdl.utilities.torch import zero_volume_tensor

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return
    P_world = MPIPartition(base_comm)

    # Create the partitions
    P_x_base = P_world.create_partition_inclusive(P_x_ranks)
    P_x = P_x_base.create_cartesian_topology_partition(P_x_shape)

    P_y_base = P_world.create_partition_inclusive(P_y_ranks)
    P_y = P_y_base.create_cartesian_topology_partition(P_y_shape)

    P_w_base = P_world.create_partition_inclusive(P_w_ranks)
    P_w = P_w_base.create_cartesian_topology_partition(P_w_shape)

    x_global_shape = np.asarray(x_global_shape)
    y_global_shape = np.asarray(y_global_shape)

    layer = DistributedLinear(P_x,
                              P_y,
                              P_w,
                              x_global_shape[1],
                              y_global_shape[1],
                              bias=False)

    x = zero_volume_tensor(x_global_shape[0])
    if P_x.active:
        x_local_shape = compute_subshape(P_x.shape, P_x.index, x_global_shape)
        x = torch.Tensor(np.random.randn(*x_local_shape))
    x.requires_grad = True

    y = layer(x)

    dy = zero_volume_tensor(x_global_shape[0])
    if P_y.active:
        dy = torch.Tensor(np.random.randn(*y.shape))

    y.backward(dy)
    dx = x.grad

    x = x.detach()
    dx = dx.detach()
    dy = dy.detach()
    y = y.detach()

    check_adjoint_test_tight(P_world, x, dx, y, dy)
Example #6
def test_excepts_mismatched_nondivisible_tensor(barrier_fence_fixture,
                                                comm_split_fixture):

    import numpy as np
    import torch

    from distdl.backends.mpi.partition import MPIPartition
    from distdl.nn.repartition import Repartition
    from distdl.utilities.slicing import compute_subshape
    from distdl.utilities.torch import zero_volume_tensor

    device = torch.device('cuda' if use_cuda else 'cpu')

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return
    P_world = MPIPartition(base_comm)

    # A tensor with size 1 in a dimension cannot be partitioned in that
    # dimension.  (See last dimension of output and tensor.)
    in_shape = (1, 4, 1, 1)
    out_shape = (1, 1, 1, 2)
    x_global_shape = np.array([1, 16, 5, 1])

    in_size = np.prod(in_shape)
    out_size = np.prod(out_shape)

    # Create the partitions
    P_x_base = P_world.create_partition_inclusive(np.arange(0, in_size))
    P_x = P_x_base.create_cartesian_topology_partition(in_shape)

    P_y_base = P_world.create_partition_inclusive(
        np.arange(P_world.size - out_size, P_world.size))
    P_y = P_y_base.create_cartesian_topology_partition(out_shape)

    with pytest.raises(ValueError) as e_info:  # noqa: F841
        layer = Repartition(P_x, P_y)
        layer = layer.to(device)

        # Forward Input
        x = zero_volume_tensor(device=device)
        if P_x.active:
            x_local_shape = compute_subshape(P_x.shape, P_x.index,
                                             x_global_shape)
            x = torch.randn(*x_local_shape, device=device)
        x.requires_grad = True

        layer(x)

    P_world.deactivate()
    P_x_base.deactivate()
    P_x.deactivate()
    P_y_base.deactivate()
    P_y.deactivate()
Example #7
def test_simple_conv2d_shape(barrier_fence_fixture,
                             comm_split_fixture,
                             P_x_ranks, P_x_shape,
                             x_global_shape,
                             y_local_shape,
                             padding):

    import numpy as np
    import torch

    from distdl.backends.mpi.partition import MPIPartition
    from distdl.nn.conv_feature import DistributedFeatureConv2d
    from distdl.utilities.slicing import compute_subshape
    from distdl.utilities.torch import zero_volume_tensor

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return
    P_world = MPIPartition(base_comm)

    # Create the partitions
    P_x_base = P_world.create_partition_inclusive(P_x_ranks)
    P_x = P_x_base.create_cartesian_topology_partition(P_x_shape)

    x_global_shape = np.asarray(x_global_shape)

    layer = DistributedFeatureConv2d(P_x,
                                     in_channels=x_global_shape[1],
                                     out_channels=10,
                                     kernel_size=[3, 3],
                                     padding=padding,
                                     bias=False)

    x = zero_volume_tensor(x_global_shape[0])
    if P_x.active:
        x_local_shape = compute_subshape(P_x.shape,
                                         P_x.index,
                                         x_global_shape)
        x = torch.zeros(*x_local_shape)
    x.requires_grad = True

    y = layer(x)

    if P_x.active:
        assert np.array_equal(np.array(y.shape), np.asarray(y_local_shape))

    P_world.deactivate()
    P_x_base.deactivate()
    P_x.deactivate()
Example #8
    def __init__(self, P_x, P_y, P_w, in_features, out_features, bias=True):

        super(DistributedLinear, self).__init__()

        # P_x ~ 1 X P_fi
        self.P_x = P_x
        # P_y ~ 1 X P_fo
        self.P_y = P_y
        # P_w ~ P_fo X P_fi
        self.P_w = P_w

        # Bias flag
        self.bias = bias

        # Broadcast layer in the x-tensor
        self.x_broadcast = Broadcast(self.P_x, self.P_w, preserve_batch=True)

        # Each worker in P_W computes its own portion of the weight tensor and then
        # stores its own PyTorch Linear layer.  Only the 0th column of the tensor
        # also stores a bias.
        if self.P_w.active:
            local_in_features = compute_subshape(P_w.shape[1], P_w.index[1],
                                                 in_features)
            local_out_features = compute_subshape(P_w.shape[0], P_w.index[0],
                                                  out_features)
            # On column 0, use the specified bias, otherwise no bias to
            # prevent double counting
            bias = self.bias if (self.P_w.index[-1] == 0) else False
            self.sublinear = torch.nn.Linear(local_in_features[0],
                                             local_out_features[0],
                                             bias=bias)

        # Sum-reduce layer to get the y-tensor
        self.y_sum_reduce = SumReduce(self.P_w,
                                      self.P_y,
                                      transpose_src=True,
                                      preserve_batch=True)
Example #9
def test_excepts_mismatched_output_partition_tensor(barrier_fence_fixture,
                                                    comm_split_fixture):

    import numpy as np
    import torch

    from distdl.backends.mpi.partition import MPIPartition
    from distdl.nn.transpose import DistributedTranspose
    from distdl.utilities.slicing import compute_subshape
    from distdl.utilities.torch import zero_volume_tensor

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return
    P_world = MPIPartition(base_comm)

    # Output partition rank must match tensor rank
    in_shape = (4, 1, 1)
    out_shape = (1, 1, 1, 2)
    x_global_shape = np.array([16, 5, 5])

    in_size = np.prod(in_shape)
    out_size = np.prod(out_shape)

    # Create the partitions
    P_x_base = P_world.create_partition_inclusive(np.arange(0, in_size))
    P_x = P_x_base.create_cartesian_topology_partition(in_shape)

    P_y_base = P_world.create_partition_inclusive(np.arange(P_world.size-out_size, P_world.size))
    P_y = P_y_base.create_cartesian_topology_partition(out_shape)

    with pytest.raises(ValueError) as e_info:  # noqa: F841
        layer = DistributedTranspose(P_x, P_y)

        # Forward Input
        x = zero_volume_tensor()
        if P_x.active:
            x_local_shape = compute_subshape(P_x.shape,
                                             P_x.index,
                                             x_global_shape)
            x = torch.Tensor(np.random.randn(*x_local_shape))
        x.requires_grad = True

        layer(x)
Example #10
def test_linear_adjoint_bias(barrier_fence_fixture, comm_split_fixture,
                             P_x_ranks, P_x_shape, P_y_ranks, P_y_shape,
                             P_w_ranks, P_w_shape, x_global_shape,
                             y_global_shape):

    import numpy as np
    import torch

    from distdl.backends.mpi.partition import MPIPartition
    from distdl.nn.linear import DistributedLinear
    from distdl.utilities.slicing import compute_subshape
    from distdl.utilities.torch import zero_volume_tensor

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return
    P_world = MPIPartition(base_comm)

    # Create the partitions
    P_x_base = P_world.create_partition_inclusive(P_x_ranks)
    P_x = P_x_base.create_cartesian_topology_partition(P_x_shape)

    P_y_base = P_world.create_partition_inclusive(P_y_ranks)
    P_y = P_y_base.create_cartesian_topology_partition(P_y_shape)

    P_w_base = P_world.create_partition_inclusive(P_w_ranks)
    P_w = P_w_base.create_cartesian_topology_partition(P_w_shape)

    x_global_shape = np.asarray(x_global_shape)
    y_global_shape = np.asarray(y_global_shape)

    layer = DistributedLinear(P_x,
                              P_y,
                              P_w,
                              x_global_shape[1],
                              y_global_shape[1],
                              bias=True)

    x = zero_volume_tensor(x_global_shape[0])
    if P_x.active:
        x_local_shape = compute_subshape(P_x.shape, P_x.index, x_global_shape)
        # For this test, we are only testing to see if the adjoint works
        # correctly for the bias term.  But the adjoint test only works on the
        # Jacobian of the linear layer.  The Jacobian block for b is 0 for x and
        # W, so killing x makes the forward operator equal to its Jacobian and
        # we can verify that the adjoint is computed correctly.
        x = torch.zeros(*x_local_shape)
    x.requires_grad = True

    y = layer(x)

    dy = zero_volume_tensor(x_global_shape[0])
    if P_y.active:
        dy = torch.Tensor(np.random.randn(*y.shape))

    y.backward(dy)

    b = zero_volume_tensor()
    db = zero_volume_tensor()
    if P_w.active and P_w.index[-1] == 0:
        b = layer.sublinear.bias.detach()
        db = layer.sublinear.bias.grad.detach()

    dy = dy.detach()
    y = y.detach()

    check_adjoint_test_tight(P_world, b, db, y, dy)
Example #11
    def _distdl_module_setup(self, input):
        r"""Transpose module setup function.

        Constructs the necessary buffers and meta information about outbound
        and inbound copies to each worker.

        This function is called every time something changes in the input
        tensor structure.  It should not be called manually.

        Parameters
        ----------
        input :
            Tuple of forward inputs.  See
            `torch.nn.Module.register_forward_pre_hook` for more details.

        """

        self._distdl_is_setup = True
        self._input_tensor_structure.fill_from_tensor(input[0])

        # If we are not an active worker, do nothing.
        if not self.P_union.active:
            return

        self.input_tensor_structure = TensorStructure(input[0])

        self.global_input_tensor_structure = \
            self._distdl_backend.assemble_global_tensor_structure(self.input_tensor_structure,
                                                                  self.P_x,
                                                                  self.P_union)
        x_global_shape = self.global_input_tensor_structure.shape

        if self.P_y.active:
            self.output_tensor_structure.shape = compute_subshape(
                self.P_y.shape, self.P_y.index, x_global_shape)

        tensor_dim = len(x_global_shape)

        if len(self.P_x_shape) != tensor_dim:
            raise ValueError(
                f"Input partition mush have same dimension "
                f"({len(self.P_x_shape)}) as input tensor rank ({tensor_dim})."
            )

        if len(self.P_y_shape) != tensor_dim:
            raise ValueError(
                f"Output partition mush have same dimension "
                f"({len(self.P_y_shape)}) as input tensor rank ({tensor_dim})."
            )

        if 1 in x_global_shape[x_global_shape != self.P_y_shape]:
            raise ValueError(
                f"Input tensor must not be size 1 "
                f"({x_global_shape}) in a dimension where "
                f"output partition is other than 1 ({self.P_y_shape}).")

        # Get the collective input lengths and origins. This may be load
        # balanced or it may not be.  Therefore we will always assume it is
        # not load balanced and just build the subshape tensor manually.
        # This output is needed everywhere so it goes to P_union.
        compute_subtensor_shapes_unbalanced = \
            self._distdl_backend.tensor_decomposition.compute_subtensor_shapes_unbalanced
        x_subtensor_shapes = compute_subtensor_shapes_unbalanced(
            self.input_tensor_structure, self.P_x, self.P_union)

        # Get the collective output lengths and origins. This will always be
        # load balanced, so we can infer the subshape tensor from the global
        # tensor shape and the shape of P_y.  At this point, every worker in
        # P_union has both of these pieces of information, so we can build it
        # with no communication.

        y_subtensor_shapes = compute_subtensor_shapes_balanced(
            self.global_input_tensor_structure, self.P_y_shape)

        # Given all subtensor shapes, we can compute the start and stop indices
        # for each partition.

        x_subtensor_start_indices = compute_subtensor_start_indices(
            x_subtensor_shapes)
        x_subtensor_stop_indices = compute_subtensor_stop_indices(
            x_subtensor_shapes)

        y_subtensor_start_indices = compute_subtensor_start_indices(
            y_subtensor_shapes)
        y_subtensor_stop_indices = compute_subtensor_stop_indices(
            y_subtensor_shapes)

        # We only need to move data to the output partition if we actually
        # have input data.  It is possible to have both input and output data,
        # either input or output data, or neither.  Hence the active guard.
        if self.P_x.active:
            x_slice = tuple([slice(i, i + 1)
                             for i in self.P_x.index] + [slice(None)])
            x_start_index = x_subtensor_start_indices[x_slice].squeeze()
            x_stop_index = x_subtensor_stop_indices[x_slice].squeeze()

            # Compute our overlaps for each output subpartition.
            for rank, P_y_index in enumerate(range_index(self.P_y_shape)):

                y_slice = tuple([slice(i, i + 1)
                                 for i in P_y_index] + [slice(None)])
                y_start_index = y_subtensor_start_indices[y_slice].squeeze()
                y_stop_index = y_subtensor_stop_indices[y_slice].squeeze()

                sl = compute_subtensor_intersection_slice(
                    x_start_index, x_stop_index, y_start_index, y_stop_index)

                if sl is not None:
                    sh = compute_nd_slice_shape(sl)
                    # If it is a self-copy, mark it so we don't have to create
                    # a potentially large buffer
                    if self.P_y.active and np.all(P_y_index == self.P_y.index):
                        partner = "self"
                    # Otherwise, reverse the mapping to get the output
                    # partner's rank in the common partition.
                    else:
                        partner = np.where(self.P_y_ranks == rank)[0][0]

                    self.P_x_to_y_overlaps.append((sl, sh, partner))

                else:
                    self.P_x_to_y_overlaps.append((None, None, None))

        # We only need to obtain data from the input partition if we actually
        # have output data.
        if self.P_y.active:
            y_slice = tuple([slice(i, i + 1)
                             for i in self.P_y.index] + [slice(None)])
            y_start_index = y_subtensor_start_indices[y_slice].squeeze()
            y_stop_index = y_subtensor_stop_indices[y_slice].squeeze()

            # Compute our overlaps for each input subpartition.
            for rank, P_x_index in enumerate(range_index(self.P_x_shape)):

                x_slice = tuple([slice(i, i + 1)
                                 for i in P_x_index] + [slice(None)])
                x_start_index = x_subtensor_start_indices[x_slice].squeeze()
                x_stop_index = x_subtensor_stop_indices[x_slice].squeeze()

                sl = compute_subtensor_intersection_slice(
                    y_start_index, y_stop_index, x_start_index, x_stop_index)

                if sl is not None:
                    sh = compute_nd_slice_shape(sl)
                    # If it is a self-copy, mark it so we don't have to create
                    # a potentially large buffer
                    if self.P_x.active and np.all(P_x_index == self.P_x.index):
                        partner = "self"
                    # Otherwise, reverse the mapping to get the input
                    # partner's rank in the common partition.
                    else:
                        partner = np.where(self.P_x_ranks == rank)[0][0]

                    self.P_y_to_x_overlaps.append((sl, sh, partner))

                else:
                    self.P_y_to_x_overlaps.append((None, None, None))

        buffs = self.allocate_transpose_buffers(
            self.buffer_manager, self.P_x_to_y_overlaps,
            self.P_y_to_x_overlaps, self.global_input_tensor_structure.dtype)
        self.P_x_to_y_buffers = buffs[0]
        self.P_y_to_x_buffers = buffs[1]
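compute_subtensor_intersection_slice presumably intersects the global index ranges of two subtensors and returns the overlap in the local coordinates of the first one, or None if there is no overlap. A one-dimensional sketch of that kind of computation:

def interval_intersection_slice(a_start, a_stop, b_start, b_stop):
    # Intersect the half-open index ranges [a_start, a_stop) and
    # [b_start, b_stop); return a slice relative to the first range,
    # or None if they do not overlap.
    start = max(a_start, b_start)
    stop = min(a_stop, b_stop)
    if start >= stop:
        return None
    return slice(start - a_start, stop - a_start)

# A worker that owns global rows [4, 8) overlaps an output piece that
# needs rows [6, 12) in its local rows 2:4.
print(interval_intersection_slice(4, 8, 6, 12))  # slice(2, 4, None)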
Example #12
    def _compute_exchange_info(self,
                               x_global_shape,
                               kernel_size,
                               stride,
                               padding,
                               dilation,
                               partition_active,
                               partition_shape,
                               partition_index):

        if not partition_active:
            return None, None, None, None

        dim = len(partition_shape)

        x_global_shape = np.atleast_1d(x_global_shape)
        kernel_size = np.atleast_1d(kernel_size)
        stride = np.atleast_1d(stride)
        padding = np.atleast_1d(padding)
        dilation = np.atleast_1d(dilation)

        def compute_lpad_length(array):
            return len(x_global_shape) - len(array)

        kernel_size = np.pad(kernel_size,
                             pad_width=(compute_lpad_length(kernel_size), 0),
                             mode='constant',
                             constant_values=1)
        stride = np.pad(stride,
                        pad_width=(compute_lpad_length(stride), 0),
                        mode='constant',
                        constant_values=1)
        padding = np.pad(padding,
                         pad_width=(compute_lpad_length(padding), 0),
                         mode='constant',
                         constant_values=0)
        dilation = np.pad(dilation,
                          pad_width=(compute_lpad_length(dilation), 0),
                          mode='constant',
                          constant_values=1)

        halo_shape = self._compute_halo_shape(partition_shape,
                                              partition_index,
                                              x_global_shape,
                                              kernel_size,
                                              stride,
                                              padding,
                                              dilation)

        recv_buffer_shape = halo_shape.copy()

        send_buffer_shape = np.zeros_like(halo_shape)

        for i in range(dim):
            lindex = [x - 1 if j == i else x for j, x in enumerate(partition_index)]
            nhalo = self._compute_halo_shape(partition_shape,
                                             lindex,
                                             x_global_shape,
                                             kernel_size,
                                             stride,
                                             padding,
                                             dilation)
            # If I have a left neighbor, my left send buffer size is my left
            # neighbor's right halo size
            if(lindex[i] > -1):
                send_buffer_shape[i, 0] = nhalo[i, 1]

            rindex = [x + 1 if j == i else x for j, x in enumerate(partition_index)]
            nhalo = self._compute_halo_shape(partition_shape,
                                             rindex,
                                             x_global_shape,
                                             kernel_size,
                                             stride,
                                             padding,
                                             dilation)
            # If I have a right neighbor, my right send buffer size is my right
            # neighbor's left halo size
            if(rindex[i] < partition_shape[i]):
                send_buffer_shape[i, 1] = nhalo[i, 0]

        x_local_shape = compute_subshape(partition_shape, partition_index, x_global_shape)
        halo_shape_with_negatives = self._compute_halo_shape(partition_shape,
                                                             partition_index,
                                                             x_global_shape,
                                                             kernel_size,
                                                             stride,
                                                             padding,
                                                             dilation,
                                                             require_nonnegative=False)
        needed_ranges = self._compute_needed_ranges(x_local_shape, halo_shape_with_negatives)

        halo_shape = halo_shape.astype(int)
        needed_ranges = needed_ranges.astype(int)

        return halo_shape, recv_buffer_shape, send_buffer_shape, needed_ranges
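The np.pad calls at the top of the method left-pad the per-dimension hyperparameters up to the full tensor rank, so the batch and channel dimensions pick up neutral values (kernel size 1, stride 1, padding 0, dilation 1). A small worked example of that padding:

import numpy as np

x_global_shape = np.array([8, 3, 32, 32])   # batch, channel, H, W
kernel_size = np.atleast_1d([3, 3])

lpad = len(x_global_shape) - len(kernel_size)
kernel_size = np.pad(kernel_size, pad_width=(lpad, 0),
                     mode='constant', constant_values=1)
print(kernel_size)  # [1 1 3 3]: a trivial kernel in the batch/channel dims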
Example #13
    def forward(ctx, input, P_union, x_global_shape, P_x, in_data, in_buffers,
                P_y, out_data, out_buffers, preserve_batch, dtype):

        ctx.P_union = P_union
        ctx.x_global_shape = x_global_shape

        ctx.P_x = P_x
        ctx.in_data = in_data
        ctx.in_buffers = in_buffers

        ctx.P_y = P_y
        ctx.out_data = out_data
        ctx.out_buffers = out_buffers

        ctx.preserve_batch = preserve_batch

        ctx.dtype = dtype

        input_requires_grad = False
        # By design, P_x is always first in the union
        if P_union.active:
            if P_x.rank == 0:
                input_requires_grad = input.requires_grad
                P_union.comm.Bcast(np.array([1 if input_requires_grad else 0]),
                                   root=0)
            else:
                irg = np.array([0], dtype=int)
                P_union.comm.Bcast(irg, root=0)
                input_requires_grad = bool(irg[0] == 1)

        ctx.input_requires_grad = input_requires_grad

        requests = []

        # Default everyone to output nothing
        if preserve_batch:
            output = zero_volume_tensor(input.shape[0])
        else:
            output = zero_volume_tensor()

        # If I am getting data, recv my output parts
        recv_count = 0
        if P_y.active:
            for (sl, sz, partner), buff in zip(out_data, out_buffers):
                if buff is not None:
                    req = P_union.comm.Irecv(buff, source=partner, tag=111)
                    requests.append(req)
                else:
                    # We add this if there is no recv so that the indices of
                    # the requests array match the indices of out_data and
                    # out_buffers.
                    requests.append(MPI.REQUEST_NULL)
                recv_count += 1

        # If I have data to share, pack and send my input parts
        send_count = 0
        if P_x.active:
            input_numpy = input.detach().numpy()
            for (sl, sz, partner), buff in zip(in_data, in_buffers):
                if buff is not None:
                    np.copyto(buff, input_numpy[tuple(sl)].ravel())
                    req = P_union.comm.Isend(buff, dest=partner, tag=111)
                    requests.append(req)
                else:
                    # We add this for symmetry, but don't really need it.
                    requests.append(MPI.REQUEST_NULL)
                send_count += 1

        # We do this after the sends so that they can get started before local
        # allocations.
        if P_y.active:
            index = P_y.index
            y_local_shape = compute_subshape(P_y.shape, index, x_global_shape)
            # TODO(#25): The dtype should not be fixed, but correcting this is
            #            a thing that needs to be resolved globally.
            output = np.zeros(y_local_shape, dtype=dtype)

        # Unpack the received data as it arrives
        completed_count = 0
        while (completed_count < len(requests)):
            status = MPI.Status()
            index = MPI.Request.Waitany(requests, status)

            # In MPI, we don't get an index back if the request is an
            # instance of MPI.REQUEST_NULL; MPI.UNDEFINED is returned instead.
            if P_y.active and index < recv_count and index != MPI.UNDEFINED:
                # Unpack my output parts
                sl, sz, partner = out_data[index]
                buff = out_buffers[index]
                if buff is not None:
                    sh = output[tuple(sl)].shape
                    np.copyto(output[tuple(sl)], buff.reshape(sh))

            completed_count += 1

        if P_y.active:
            output = torch.from_numpy(output)
            output.requires_grad = input_requires_grad

        return output
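The request bookkeeping above follows a common mpi4py pattern: post all Irecvs and Isends up front, insert MPI.REQUEST_NULL placeholders so the request indices stay aligned with out_data and out_buffers, then drain everything with Waitany. A stripped-down sketch of that pattern outside of DistDL (a toy gather of rank ids to rank 0) might look like:

import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()

# Toy exchange: every rank sends its rank id to rank 0.
requests = []
recv_buffers = []
if rank == 0:
    for source in range(size):
        buff = np.zeros(1, dtype=np.int64)
        recv_buffers.append(buff)
        requests.append(comm.Irecv(buff, source=source, tag=111))

send_buff = np.array([rank], dtype=np.int64)
requests.append(comm.Isend(send_buff, dest=0, tag=111))

# Drain the requests as they complete, as in the Waitany loop above.
completed = 0
while completed < len(requests):
    status = MPI.Status()
    index = MPI.Request.Waitany(requests, status)
    if rank == 0 and index != MPI.UNDEFINED and index < len(recv_buffers):
        # recv_buffers[index] now holds the payload from that source.
        pass
    completed += 1

if rank == 0:
    print([int(b[0]) for b in recv_buffers])  # [0, 1, ..., size-1]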
Example #14
def test_general_conv1d_adjoint_bias(barrier_fence_fixture, comm_split_fixture,
                                     P_x_ranks, P_x_shape, P_y_ranks,
                                     P_y_shape, P_w_ranks, P_w_shape,
                                     x_global_shape):

    import numpy as np
    import torch

    from distdl.backends.mpi.partition import MPIPartition
    from distdl.nn.conv_general import DistributedGeneralConv1d
    from distdl.utilities.slicing import compute_subshape
    from distdl.utilities.torch import zero_volume_tensor

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return
    P_world = MPIPartition(base_comm)

    # Create the partitions
    P_x_base = P_world.create_partition_inclusive(P_x_ranks)
    P_x = P_x_base.create_cartesian_topology_partition(P_x_shape)

    P_y_base = P_world.create_partition_inclusive(P_y_ranks)
    P_y = P_y_base.create_cartesian_topology_partition(P_y_shape)

    P_w_base = P_world.create_partition_inclusive(P_w_ranks)
    P_w = P_w_base.create_cartesian_topology_partition(P_w_shape)

    x_global_shape = np.asarray(x_global_shape)

    layer = DistributedGeneralConv1d(P_x,
                                     P_y,
                                     P_w,
                                     in_channels=x_global_shape[1],
                                     out_channels=10,
                                     kernel_size=[3],
                                     bias=True)

    x = zero_volume_tensor(x_global_shape[0])
    if P_x.active:
        x_local_shape = compute_subshape(P_x.shape, P_x.index, x_global_shape)
        x = torch.zeros(*x_local_shape)
    x.requires_grad = True

    y = layer(x)

    dy = zero_volume_tensor(x_global_shape[0])
    if P_y.active:
        dy = torch.randn(*y.shape)

    y.backward(dy)

    b = zero_volume_tensor()
    db = zero_volume_tensor()
    if P_w.active and layer.stores_bias:
        b = layer.bias.detach()
        db = layer.bias.grad.detach()

    dy = dy.detach()
    y = y.detach()

    check_adjoint_test_tight(P_world, b, db, y, dy)

    P_world.deactivate()
    P_x_base.deactivate()
    P_x.deactivate()
    P_y_base.deactivate()
    P_y.deactivate()
    P_w_base.deactivate()
    P_w.deactivate()
Example #15
def test_repartition_dtype(barrier_fence_fixture, comm_split_fixture, dtype,
                           test_backward, P_x_ranks, P_x_shape, P_y_ranks,
                           P_y_shape, x_global_shape):

    import torch

    from distdl.backends.mpi.partition import MPIPartition
    from distdl.nn.repartition import Repartition
    from distdl.utilities.slicing import compute_subshape
    from distdl.utilities.torch import zero_volume_tensor

    device = torch.device('cuda' if use_cuda else 'cpu')

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return
    P_world = MPIPartition(base_comm)

    # Create the partitions
    P_x_base = P_world.create_partition_inclusive(P_x_ranks)
    P_x = P_x_base.create_cartesian_topology_partition(P_x_shape)

    P_y_base = P_world.create_partition_inclusive(P_y_ranks)
    P_y = P_y_base.create_cartesian_topology_partition(P_y_shape)

    # The global tensor size is the same for x and y
    layer = Repartition(P_x, P_y, preserve_batch=False)
    layer = layer.to(device)

    # Forward Input
    x = zero_volume_tensor(dtype=dtype, device=device)
    if P_x.active:
        x_local_shape = compute_subshape(P_x.shape, P_x.index, x_global_shape)
        x = 10 * torch.randn(*x_local_shape, device=device).to(dtype)

    x.requires_grad = test_backward
    # y = F @ x
    y = layer(x)
    if P_y.active:
        assert y.dtype == dtype

    if test_backward:
        # Adjoint Input
        dy = zero_volume_tensor(dtype=dtype, device=device)
        if P_y.active:
            y_local_shape = compute_subshape(P_y.shape, P_y.index,
                                             x_global_shape)
            dy = 10 * torch.randn(*y_local_shape, device=device).to(dtype)

        # dx = F* @ dy
        y.backward(dy)
        dx = x.grad
        if P_x.active:
            assert dx.dtype == dtype

    P_world.deactivate()
    P_x_base.deactivate()
    P_x.deactivate()
    P_y_base.deactivate()
    P_y.deactivate()
Example #16
def test_halo_exchange_adjoint(barrier_fence_fixture, comm_split_fixture,
                               P_x_ranks, P_x_shape, x_global_shape, dtype,
                               kernel_size, stride, padding, dilation,
                               MockKernelStyle):
    import numpy as np
    import torch
    import torch.nn.functional as F

    from distdl.backends.mpi.partition import MPIPartition
    from distdl.nn.halo_exchange import HaloExchange
    from distdl.utilities.slicing import compute_subshape
    from distdl.utilities.torch import distdl_padding_to_torch_padding
    from distdl.utilities.torch import zero_volume_tensor

    device = torch.device('cuda' if use_cuda else 'cpu')

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return
    P_world = MPIPartition(base_comm)
    P_x_base = P_world.create_partition_inclusive(P_x_ranks)
    P_x = P_x_base.create_cartesian_topology_partition(P_x_shape)

    x_global_shape = np.asarray(x_global_shape)
    kernel_size = np.asarray(kernel_size)
    stride = np.asarray(stride)
    padding = np.asarray(padding)
    dilation = np.asarray(dilation)

    halo_shape = None
    recv_buffer_shape = None
    send_buffer_shape = None
    if P_x.active:
        mockup_layer = MockKernelStyle()
        exchange_info = mockup_layer._compute_exchange_info(
            x_global_shape, kernel_size, stride, padding, dilation, P_x.active,
            P_x.shape, P_x.index)
        halo_shape = exchange_info[0]
        recv_buffer_shape = exchange_info[1]
        send_buffer_shape = exchange_info[2]

    halo_layer = HaloExchange(P_x, halo_shape, recv_buffer_shape,
                              send_buffer_shape)
    halo_layer = halo_layer.to(device)

    x = zero_volume_tensor(x_global_shape[0], device=device)
    dy = zero_volume_tensor(x_global_shape[0], device=device)
    if P_x.active:
        x_local_shape = compute_subshape(P_x.shape, P_x.index, x_global_shape)

        padding = distdl_padding_to_torch_padding(halo_shape)

        x = torch.randn(*x_local_shape, device=device).to(dtype)

        # Pad the input with the halo space.  We are only testing the behavior of
        # the halo exchange so the input must be padded before we can do anything.
        x = F.pad(x, pad=padding, mode="constant", value=0)

        # dy is also padded, but we want it to start with data everywhere,
        # including in the halo region.
        dy = torch.randn(*x.shape, device=device).to(dtype)

    x.requires_grad = True

    # Halo Exchange (both fwd and adj) is in-place.  So, we copy the input
    # data and save the originals for the adjoint test.  Because it is
    # in-place, the clones themselves are modified.  This also avoids
    # in-place operations on leaf nodes.

    x_clone = x.clone()
    dy_clone = dy.clone()

    # x_clone is modified in place by halo_layer, but we assign y to
    # reference it for clarity.  y and x_clone are the same object.
    y = halo_layer(x_clone)

    # dy_clone is modified in place by halo_layer-adjoint, but we assign dx to
    # reference it for clarity.  dx and dy_clone are the same object.
    # dx is not in the grad field as you might expect because the operation is
    # in-place.
    y.backward(dy_clone)
    dx = dy_clone

    x = x.detach()
    dx = dx.detach()
    dy = dy.detach()
    y = y.detach()

    check_adjoint_test_tight(P_world, x, dx, y, dy)

    P_world.deactivate()
    P_x_base.deactivate()
    P_x.deactivate()
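distdl_padding_to_torch_padding converts the (dim, 2) halo array into the flat, last-dimension-first tuple that torch.nn.functional.pad expects. Assuming that layout for halo_shape, a stand-in conversion and its effect would be:

import numpy as np
import torch
import torch.nn.functional as F

def halo_to_torch_padding(halo_shape):
    # halo_shape[d] = [left, right] for dimension d; F.pad wants
    # (left_last, right_last, left_secondlast, right_secondlast, ...).
    return tuple(int(p) for row in reversed(np.asarray(halo_shape)) for p in row)

halo = np.array([[0, 0], [0, 0], [1, 1], [2, 0]])  # pad H by 1/1, W by 2/0
x = torch.randn(1, 3, 4, 4)
y = F.pad(x, pad=halo_to_torch_padding(halo), mode="constant", value=0)
print(y.shape)  # torch.Size([1, 3, 6, 6])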
Example #17
    def __init__(self, P_x, P_y, P_w,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 padding_mode='zeros',
                 dilation=1,
                 groups=1,
                 bias=True,
                 buffer_manager=None):

        super(DistributedGeneralConvBase, self).__init__()

        # P_x is 1    x P_ci x P_d-1 x ... x P_0
        self.P_x = P_x
        # P_y is 1    x P_co x P_d-1 x ... x P_0
        self.P_y = P_y
        # P_w is P_co x P_ci x P_d-1 x ... x P_0
        self.P_w = P_w

        # Back-end specific buffer manager for economic buffer allocation
        if buffer_manager is None:
            buffer_manager = self._distdl_backend.BufferManager()
        elif type(buffer_manager) is not self._distdl_backend.BufferManager:
            raise ValueError("Buffer manager type does not match backend.")
        self.buffer_manager = buffer_manager

        # Even inactive workers need some partition union
        self.P_union = self._distdl_backend.Partition()
        if not (self.P_x.active or
                self.P_y.active or
                self.P_w.active):
            return

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = self._expand_parameter(kernel_size)
        self.stride = self._expand_parameter(stride)
        self.padding = self._expand_parameter(padding)
        self.padding_mode = padding_mode
        self.dilation = self._expand_parameter(dilation)
        self.groups = groups
        self.use_bias = bias

        # This guarantees that P_union rank 0 has the kernel size, stride,
        # padding, and dilation factors
        P_union_temp = P_w.create_partition_union(P_x)
        self.P_union = P_union_temp.create_partition_union(P_y)

        # Release the temporary resources
        P_union_temp.deactivate()

        # Ensure that all workers have the full size and structure of P_w
        P_w_shape = None
        if self.P_union.rank == 0:
            P_w_shape = np.array(P_w.shape, dtype=int)
        P_w_shape = self.P_union.broadcast_data(P_w_shape, root=0)

        P_co = P_w_shape[0]
        P_ci = P_w_shape[1]
        P_channels = [P_co, P_ci]

        # Ensure that P_x and P_w are correctly aligned.  We also produce a
        # new P_x that is shaped like 1 x P_ci x P_d-1 x ... x P_0, to assist
        # with broadcasts.
        P_x_new_shape = []
        if self.P_x.active:
            if(np.any(P_x.shape[2:] != P_w_shape[2:])):
                raise ValueError("Spatial components of P_x and P_w must match.")
            if P_w_shape[1] != P_x.shape[1]:
                raise ValueError("Index 2 of P_w dimension must match input channel partition.")
            P_x_new_shape = list(P_x.shape)
            P_x_new_shape.insert(1, 1)
            # Currently a hack, removing the batch dimension because P_w does
            # not have one. This is OK because we assume there are no partitions
            # in the batch dimension.
            P_x_new_shape = np.asarray(P_x_new_shape[1:], dtype=int)

        # For the purposes of this layer, we re-cast P_x to have the extra
        # dimension.  This has no impact outside of the layer or on the results.
        self.P_x = self.P_x.create_cartesian_topology_partition(P_x_new_shape)

        # Ensure that P_y and P_w are correctly aligned.  We also produce a
        # new P_y that is shaped like P_co x 1 x P_d-1 x ... x P_0, to assist
        # with broadcasts.
        P_y_new_shape = []
        if self.P_y.active:
            if(np.any(P_y.shape[2:] != P_w_shape[2:])):
                raise ValueError("Spatial components of P_y and P_w must match.")
            if P_w_shape[0] != P_y.shape[1]:
                raise ValueError("Index 1 of P_w dimension must match output channel partition.")
            P_y_new_shape = list(P_y.shape)
            P_y_new_shape.insert(2, 1)
            # Currently a hack, removing the batch dimension because P_w does
            # not have one. This is OK because we assume there are no partitions
            # in the batch dimension.
            P_y_new_shape = np.asarray(P_y_new_shape[1:], dtype=int)

        # For the purposes of this layer, we re-cast P_y to have the extra
        # dimension.  This has no impact outside of the layer or on the results.
        self.P_y = self.P_y.create_cartesian_topology_partition(P_y_new_shape)

        P_spatial = P_w_shape[2:]

        self.serial = self.P_w.size == 1

        if self.serial:
            self.conv_layer = self.TorchConvType(in_channels=in_channels,
                                                 out_channels=out_channels,
                                                 kernel_size=self.kernel_size,
                                                 stride=self.stride,
                                                 padding=self.padding,
                                                 padding_mode=self.padding_mode,
                                                 dilation=self.dilation,
                                                 groups=self.groups,
                                                 bias=self.use_bias)
            self.weight = self.conv_layer.weight
            self.bias = self.conv_layer.bias
            return

        # Need to figure out any padding necessary to handle global padding.
        # This is only on the input tensor.  The convolution will not use
        # any implicit padding, so the work partition does not need it.
        if self.P_x.active:
            dims = len(self.P_x.shape)

            # We will be using global padding to compute local padding,
            # so expand it to a numpy array
            global_padding = np.pad(self.padding,
                                    pad_width=(dims-len(self.padding), 0),
                                    mode='constant',
                                    constant_values=0)
            self.global_padding = global_padding

            pad_left_right = self.global_padding.reshape((dims, 1)) + np.zeros((dims, 2), dtype=int)
            self.local_padding = self._compute_local_padding(pad_left_right)

        # Workers can either store the learnable weights and bias, or they
        # need copies of it.
        self.receives_weight = False
        self.stores_weight = False
        self.receives_bias = False
        self.stores_bias = False

        # Determine root partitions, initialize weights there
        if self.P_w.active:
            # All of P_w always receives the weight
            self.receives_weight = True

            # This subset is taken to be the origin of the spatial component
            w_root_subset = []
            for i, c in enumerate(range_index(P_w.shape)):
                c = np.asarray(c)
                # Find the P_co x P_ci x 1 x ... x 1 subset to store the weights
                if np.all(c[2:] == 0):
                    w_root_subset.append(i)

            P_wr_base = self.P_w.create_partition_inclusive(w_root_subset)
            # ones are needed so the broadcast will work
            self.P_wr = P_wr_base.create_cartesian_topology_partition([P_co, P_ci] + [1]*len(P_spatial))
            self.stores_weight = self.P_wr.active

            # Release temporary resources
            P_wr_base.deactivate()

            b_subset = []
            for i, c in enumerate(range_index(P_w.shape)):
                c = np.asarray(c)
                # Find the P_co x 1 x P_0 x ... x P_D-1 subset that needs
                # biases in its calculation.  This is everywhere that the
                # input channel index is 0.
                if c[1] == 0:
                    b_subset.append(i)

            P_b_base = self.P_w.create_partition_inclusive(b_subset)
            self.P_b = P_b_base.create_cartesian_topology_partition([P_co] + [1] + list(P_spatial))
            self.receives_bias = self.P_b.active and self.use_bias

            # Release temporary resources
            P_b_base.deactivate()

            # Now find the subset of _that_ which actually stores the
            # learnable parameter.
            b_root_subset = []
            for i, c in enumerate(range_index(P_w.shape)):
                c = np.asarray(c)
                # Find the P_co x 1 x 1 x ... x 1 subset to store the biases
                if np.all(c[1:] == 0):
                    b_root_subset.append(i)

            P_br_base = self.P_w.create_partition_inclusive(b_root_subset)
            # ones are needed so the broadcast will work
            self.P_br = P_br_base.create_cartesian_topology_partition([P_co] + [1] + [1]*len(P_spatial))
            self.stores_bias = self.P_br.active and self.use_bias

            # Release temporary resources
            P_br_base.deactivate()

            # Correct the input arguments based on local properties
            # This ensures that the in and out channels are correctly shared.
            local_co, local_ci = compute_subshape(P_channels,
                                                  P_w.index[0:2],
                                                  [out_channels, in_channels])
            self.conv_layer = self.TorchConvType(in_channels=local_ci,
                                                 out_channels=local_co,
                                                 kernel_size=self.kernel_size,
                                                 stride=self.stride,
                                                 padding=0,
                                                 padding_mode='zeros',
                                                 dilation=self.dilation,
                                                 groups=groups,
                                                 bias=self.receives_bias)

            # If we store the weight it is a learnable parameter iff it is
            # learnable by default in the layer, which it is.
            if self.stores_weight:
                self.weight = torch.nn.Parameter(self.conv_layer.weight.detach())
            else:
                self.register_buffer('weight', zero_volume_tensor())
            # This always exists so we can copy the property
            self.weight.requires_grad = self.conv_layer.weight.requires_grad

            # https://discuss.pytorch.org/t/assign-parameters-to-nn-module-and-have-grad-fn-track-it/62677/2
            new_weight = self.conv_layer.weight.detach() * 0
            new_weight.requires_grad = self.conv_layer.weight.requires_grad
            del self.conv_layer.weight
            self.conv_layer.weight = new_weight

            # If we store the bias, it is a learnable parameter iff it is
            # learnable by default in the layer, which is only true if it
            # exists.
            if self.stores_bias:
                self.bias = torch.nn.Parameter(self.conv_layer.bias.detach())
            else:
                if self.use_bias:
                    self.register_buffer('bias', zero_volume_tensor())
                else:
                    self.register_buffer('bias', None)
            # This does not always exist, but when it does we can copy the
            # property.
            if self.receives_bias:
                self.bias.requires_grad = self.conv_layer.bias.requires_grad

                # https://discuss.pytorch.org/t/assign-parameters-to-nn-module-and-have-grad-fn-track-it/62677/2
                new_bias = self.conv_layer.bias.detach() * 0
                new_bias.requires_grad = self.conv_layer.bias.requires_grad
                del self.conv_layer.bias
                self.conv_layer.bias = new_bias

        else:
            # Workers not in P_w don't have a weight or bias.
            self.register_buffer('weight', zero_volume_tensor())
            if self.use_bias:
                self.register_buffer('bias', zero_volume_tensor())
            else:
                self.register_buffer('bias', None)

        # Now we need to share the kernel structure.  The size of the kernel
        # is always the spatial dimensions.
        self.conv_kernel_size = None
        self.conv_stride = None
        self.conv_padding = None
        self.conv_dilation = None

        # By construction, rank 0 of the union should always have all of this
        # information, because it will always construct a local conv layer. We
        # rely on the local conv layer to properly fill out this information
        # from the defaults.  This info is required for all workers on the
        # input and output partitions because it is needed to construct the
        # halos.  Rank 0 in the union shares it with everyone.
        if self.P_union.rank == 0:
            self.conv_kernel_size = np.array(self.conv_layer.kernel_size, dtype=int)
            self.conv_stride = np.array(self.conv_layer.stride, dtype=int)
            self.conv_padding = np.array(self.conv_layer.padding, dtype=int)
            self.conv_dilation = np.array(self.conv_layer.dilation, dtype=int)
        self.conv_kernel_size = self.P_union.broadcast_data(self.conv_kernel_size, root=0)
        self.conv_stride = self.P_union.broadcast_data(self.conv_stride, root=0)
        self.conv_padding = self.P_union.broadcast_data(self.conv_padding, root=0)
        self.conv_dilation = self.P_union.broadcast_data(self.conv_dilation, root=0)

        # We need to be able to remove some data from the input to the conv
        # layer but again need to defer.
        self.needed_slices = None

        # For the halo layer we also defer construction, so that we can have
        # the halo shape for the input.  The halo will allocate its own
        # buffers, but it needs this information at construction to be able
        # to do this in the pre-forward hook.
        self.halo_layer = None

        # Variables for tracking input changes and buffer construction
        self._distdl_is_setup = False
        self._input_tensor_structure = TensorStructure()

        # Some layers, namely those that require no information about the
        # input tensor for their setup, can be built now.
        if P_w.active:
            self.w_broadcast = Broadcast(self.P_wr, self.P_w, preserve_batch=False)

        if self.receives_bias or self.stores_bias:
            self.b_broadcast = Broadcast(self.P_br, self.P_b, preserve_batch=False)

        self.x_broadcast = Broadcast(self.P_x, self.P_w, preserve_batch=True)
        self.y_sum_reduce = SumReduce(self.P_w, self.P_y, preserve_batch=True)
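
For orientation, the pieces assembled above compose in the forward pass roughly as follows. This is a minimal sketch under stated assumptions (attribute names taken from the constructor; it is not the library's actual forward method, which also builds the deferred pad, unpad, slicing, and halo pieces in a pre-forward hook):

def forward_sketch(self, input):
    # Sketch only: the forward data flow implied by the constructor above.
    if self.serial:
        return self.conv_layer(input)

    if self.P_w.active:
        # Re-broadcast the learnable weight (and bias, where one is used)
        # from the root subsets of P_w to all of P_w before the local conv.
        self.conv_layer.weight = self.w_broadcast(self.weight)
        if self.receives_bias or self.stores_bias:
            self.conv_layer.bias = self.b_broadcast(self.bias)

    # Broadcast the input from P_x onto P_w, exchange halos, drop entries
    # the local convolution does not need, apply it, then sum-reduce the
    # partial results onto P_y.
    x = self.x_broadcast(input)
    if self.P_w.active:
        x = self.halo_layer(x)
        x = x[self.needed_slices]
        x = self.conv_layer(x)
    return self.y_sum_reduce(x)

In this layout, the weight broadcast runs from the P_wr root subset over all of P_w, the input broadcast replicates x across the output-channel dimension of P_w, and the sum-reduce collapses partial sums over the input-channel dimension onto P_y.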
Example No. 18
    def __init__(self,
                 P_x,
                 P_y,
                 P_w,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 padding_mode='zeros',
                 dilation=1,
                 groups=1,
                 bias=True,
                 *args,
                 **kwargs):

        super(DistributedChannelConvBase, self).__init__()

        # P_x is 1    x P_ci x 1 x ... x 1
        self.P_x = P_x
        # P_y is 1    x P_co x 1 x ... x 1
        self.P_y = P_y
        # P_w is P_co x P_ci x 1 x ... x 1
        self.P_w = P_w

        # Even inactive workers need some partition union
        P_union = self._distdl_backend.Partition()

        if not (self.P_x.active or self.P_y.active or self.P_w.active):
            return

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = self._expand_parameter(kernel_size)
        self.stride = self._expand_parameter(stride)
        self.padding = self._expand_parameter(padding)
        self.padding_mode = padding_mode
        self.dilation = self._expand_parameter(dilation)
        self.groups = groups
        self.use_bias = bias

        # Build a union of P_w, P_x, and P_y so that rank 0 of the union can
        # share the structure of P_w with every worker
        P_union_temp = P_w.create_partition_union(P_x)
        P_union = P_union_temp.create_partition_union(P_y)

        # Ensure that all workers have the full size and structure of P_w
        P_w_shape = None
        if P_union.rank == 0:
            P_w_shape = np.array(P_w.shape, dtype=int)
        P_w_shape = P_union.broadcast_data(P_w_shape, root=0)

        # Release the temporary resources
        P_union_temp.deactivate()
        P_union.deactivate()

        P_co = P_w_shape[0]
        P_ci = P_w_shape[1]
        P_channels = [P_co, P_ci]

        # Ensure that P_x and P_w are correctly aligned.  We also produce a
        # new P_x that is shaped like 1 x P_ci x 1 x ... x 1, to assist with
        # broadcasts.
        P_x_new_shape = []
        if self.P_x.active:
            if (np.any(P_x.shape[2:] != P_w_shape[2:])):
                raise ValueError(
                    "Spatial components of P_x and P_w must match.")
            if (np.any(P_x.shape[2:] != np.ones(len(P_x.shape[2:])))):
                raise ValueError(
                    "Spatial components of P_x must be 1 x ... x 1.")
            if P_w_shape[1] != P_x.shape[1]:
                raise ValueError(
                    "Index 2 of P_w dimension must match input channel partition."
                )
            P_x_new_shape = list(P_x.shape)
            P_x_new_shape.insert(1, 1)
            # Currently a hack, removing the batch dimension because P_w does
            # not have one. This is OK because we assume there are no partitions
            # in the batch dimension.
            P_x_new_shape = np.asarray(P_x_new_shape[1:], dtype=int)

        # For the purposes of this layer, we re-cast P_x to have the extra
        # dimension.  This has no impact outside of the layer or on the results.
        self.P_x = self.P_x.create_cartesian_topology_partition(P_x_new_shape)

        # Ensure that P_y and P_w are correctly aligned.  We also produce a
        # new P_y that is shaped like P_co x 1 x 1 x ... x 1, to assist with
        # broadcasts.
        P_y_new_shape = []
        if self.P_y.active:
            if (np.any(P_y.shape[2:] != P_w_shape[2:])):
                raise ValueError(
                    "Spatial components of P_y and P_w must match.")
            if (np.any(P_y.shape[2:] != np.ones(len(P_y.shape[2:])))):
                raise ValueError(
                    "Spatial components of P_y must be 1 x ... x 1.")
            if P_w_shape[0] != P_y.shape[1]:
                raise ValueError(
                    "Index 1 of P_w dimension must match output channel partition."
                )
            P_y_new_shape = list(P_y.shape)
            P_y_new_shape.insert(2, 1)
            # Currently a hack, removing the batch dimension because P_w does
            # not have one. This is OK because we assume there are no partitions
            # in the batch dimension.
            P_y_new_shape = np.asarray(P_y_new_shape[1:], dtype=int)

        # For the purposes of this layer, we re-cast P_y to have the extra
        # dimension.  This has no impact outside of the layer or on the results.
        self.P_y = self.P_y.create_cartesian_topology_partition(P_y_new_shape)

        self.serial = self.P_w.size == 1

        if self.serial:
            self.conv_layer = self.TorchConvType(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=self.kernel_size,
                stride=self.stride,
                padding=self.padding,
                padding_mode=self.padding_mode,
                dilation=self.dilation,
                groups=self.groups,
                bias=self.use_bias)
            self.weight = self.conv_layer.weight
            self.bias = self.conv_layer.bias
            return

        # Flag if the global bias is set
        self.global_bias = bias

        # Flags if current worker stores (part of) the bias locally.
        self.stores_bias = False

        if self.P_w.active:

            # Let the P_co column store the bias if it is to be used
            self.stores_bias = self.P_w.index[1] == 0 and self.use_bias

            # Correct the input arguments based on local properties.
            # This ensures the in and out channels are correctly partitioned
            # across the workers.
            local_co, local_ci = compute_subshape(P_channels, P_w.index[0:2],
                                                  [out_channels, in_channels])
            self.conv_layer = self.TorchConvType(
                in_channels=local_ci,
                out_channels=local_co,
                kernel_size=self.kernel_size,
                stride=self.stride,
                padding=self.padding,
                padding_mode=self.padding_mode,
                dilation=self.dilation,
                groups=groups,
                bias=self.stores_bias)

        # Workers in P_w alias the conv layer to get their weight and perhaps
        # biases.  Every other worker doesn't have a weight or bias.
        if self.P_w.active:
            self.weight = self.conv_layer.weight
            if self.stores_bias:
                self.bias = self.conv_layer.bias
            else:
                if self.use_bias:
                    self.register_buffer('bias', zero_volume_tensor())
                else:
                    self.register_buffer('bias', None)
        else:
            self.register_buffer('weight', zero_volume_tensor())
            if self.use_bias:
                self.register_buffer('bias', zero_volume_tensor())
            else:
                self.register_buffer('bias', None)

        # Variables for tracking input changes and buffer construction
        self._distdl_is_setup = False
        self._input_tensor_structure = TensorStructure()

        self.x_broadcast = Broadcast(self.P_x, self.P_w, preserve_batch=True)
        self.y_sum_reduce = SumReduce(self.P_w, self.P_y, preserve_batch=True)
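
A minimal usage sketch for this constructor, assuming four MPI ranks (e.g. mpirun -np 4), the MPI backend used throughout these examples, and a hypothetical concrete subclass that binds TorchConvType to torch.nn.Conv1d; the library's own subclass name may differ:

import torch
from mpi4py import MPI

from distdl.backends.mpi.partition import MPIPartition


class ChannelConv1d(DistributedChannelConvBase):
    # Hypothetical binding; the base class above is not used directly.
    TorchConvType = torch.nn.Conv1d


P_world = MPIPartition(MPI.COMM_WORLD)

# Hypothetical layout: input channels split 2 ways, output channels split
# 2 ways, no spatial partitioning.
P_x_base = P_world.create_partition_inclusive([0, 1])
P_x = P_x_base.create_cartesian_topology_partition([1, 2, 1])

P_y_base = P_world.create_partition_inclusive([2, 3])
P_y = P_y_base.create_cartesian_topology_partition([1, 2, 1])

P_w_base = P_world.create_partition_inclusive([0, 1, 2, 3])
P_w = P_w_base.create_cartesian_topology_partition([2, 2, 1])

# Each worker in P_w ends up with a local Conv1d over 8/2 = 4 input and
# 16/2 = 8 output channels.
layer = ChannelConv1d(P_x, P_y, P_w,
                      in_channels=8, out_channels=16, kernel_size=3)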
Example No. 19
def test_repartition_identity(barrier_fence_fixture, comm_split_fixture,
                              P_x_ranks, P_x_shape, x_global_shape, balanced):

    import numpy as np
    import torch

    from distdl.backends.mpi.partition import MPIPartition
    from distdl.nn.repartition import Repartition
    from distdl.utilities.slicing import compute_subshape
    from distdl.utilities.torch import zero_volume_tensor

    # use_cuda is assumed to be defined at module scope in the original test file.
    device = torch.device('cuda' if use_cuda else 'cpu')

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return
    P_world = MPIPartition(base_comm)

    # Create the partitions
    P_x_base = P_world.create_partition_inclusive(P_x_ranks)
    P_x = P_x_base.create_cartesian_topology_partition(P_x_shape)

    P_y = P_x

    # The global tensor size is the same for x and y
    layer = Repartition(P_x, P_y, preserve_batch=False)
    layer = layer.to(device)

    # Forward Input
    x = zero_volume_tensor(device=device)
    if P_x.active:
        if balanced:
            x_local_shape = compute_subshape(P_x.shape, P_x.index,
                                             x_global_shape)
        else:
            quotient = np.atleast_1d(x_global_shape) // np.atleast_1d(
                P_x_shape)
            remainder = np.atleast_1d(x_global_shape) % np.atleast_1d(
                P_x_shape)
            loc = np.where(P_x.index == 0)
            x_local_shape = quotient.copy()
            x_local_shape[loc] += remainder[loc]

        x = torch.randn(*x_local_shape, device=device)

    x.requires_grad = True

    # Adjoint Input
    dy = zero_volume_tensor(device=device)
    if P_y.active:
        y_local_shape = compute_subshape(P_y.shape, P_y.index, x_global_shape)
        dy = torch.randn(*y_local_shape, device=device)

    # y = F @ x
    y = layer(x)

    # In the balanced case, this should be a true identity, so there should
    # be no communication performed, just self-copies.
    if balanced:
        for sl, sz, p in layer.P_x_to_y_overlaps:
            assert p == "self" or (sl, sz, p) == (None, None, None)
        for sl, sz, p in layer.P_y_to_x_overlaps:
            assert p == "self" or (sl, sz, p) == (None, None, None)

    # dx = F* @ dy
    y.backward(dy)
    dx = x.grad

    x = x.detach()
    dx = dx.detach()
    dy = dy.detach()
    y = y.detach()

    check_adjoint_test_tight(P_world, x, dx, y, dy)

    P_world.deactivate()
    P_x_base.deactivate()
    P_x.deactivate()
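
For reference, check_adjoint_test_tight verifies the adjoint identity <F x, dy> = <x, F* dy>, with both inner products accumulated across all workers. A minimal sketch of such a check, stated as an assumption about its behavior rather than its actual implementation:

import numpy as np
import torch
from mpi4py import MPI


def adjoint_check_sketch(P_world, x, dx, y, dy, rtol=1e-6):
    # Local contributions; zero-volume tensors on inactive workers add 0.
    local = np.array([float(torch.sum(y * dy)), float(torch.sum(x * dx))])

    # Accumulate both inner products over every worker in P_world.
    total = np.zeros(2)
    P_world.comm.Allreduce(local, total, op=MPI.SUM)

    # <F x, dy> and <x, F* dy> must agree to a tight tolerance.
    assert np.isclose(total[0], total[1], rtol=rtol)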
Example No. 20
    def __init__(self,
                 P_x,
                 P_y,
                 P_w,
                 in_channels=1,
                 out_channels=1,
                 bias=True,
                 *args,
                 **kwargs):

        super(DistributedGeneralConvBase, self).__init__()

        # P_x is 1    x P_ci x P_d-1 x ... x P_0
        self.P_x = P_x
        # P_y is 1    x P_co x P_d-1 x ... x P_0
        self.P_y = P_y
        # P_w is P_co x P_ci x P_d-1 x ... x P_0
        self.P_w = P_w

        self.P_union = self._distdl_backend.Partition()
        if not (self.P_x.active or self.P_y.active or self.P_w.active):
            return

        # This guarantees that P_union rank 0 has the kernel size, stride,
        # padding, and dilation factors
        P_union = P_w.create_partition_union(P_x)
        P_union = P_union.create_partition_union(P_y)
        self.P_union = P_union

        P_w_shape = None
        if P_union.rank == 0:
            P_w_shape = np.array(P_w.shape, dtype=int)
        P_w_shape = P_union.broadcast_data(P_w_shape, root=0)

        P_co = P_w_shape[0]
        P_ci = P_w_shape[1]
        P_channels = [P_co, P_ci]

        P_x_new_shape = []
        if self.P_x.active:
            if (np.any(P_x.shape[2:] != P_w_shape[2:])):
                raise ValueError(
                    "Spatial components of P_x and P_w must match.")
            if P_w_shape[1] != P_x.shape[1]:
                raise ValueError(
                    "Index 2 of P_w dimension must match input channel partition."
                )
            P_x_new_shape = list(P_x.shape)
            P_x_new_shape.insert(1, 1)
            # Currently a hack, removing the batch dimension because P_w does
            # not have one. This is OK because we assume there are no partitions
            # in the batch dimension.
            P_x_new_shape = np.asarray(P_x_new_shape[1:], dtype=int)

        # For the purposes of this layer, we re-cast P_x to have the extra
        # dimension.  This has no impact outside of the layer or on the results.
        self.P_x = self.P_x.create_cartesian_topology_partition(P_x_new_shape)

        P_y_new_shape = []
        if self.P_y.active:
            if (np.any(P_y.shape[2:] != P_w_shape[2:])):
                raise ValueError(
                    "Spatial components of P_y and P_w must match.")
            if P_w_shape[0] != P_y.shape[1]:
                raise ValueError(
                    "Index 1 of P_w dimension must match output channel partition."
                )
            P_y_new_shape = list(P_y.shape)
            P_y_new_shape.insert(2, 1)
            # Currently a hack, removing the batch dimension because P_w does
            # not have one. This is OK because we assume there are no partitions
            # in the batch dimension.
            P_y_new_shape = np.asarray(P_y_new_shape[1:], dtype=int)

        # For the purposes of this layer, we re-cast P_y to have the extra
        # dimension.  This has no impact outside of the layer or on the results.
        self.P_y = self.P_y.create_cartesian_topology_partition(P_y_new_shape)

        P_spatial = P_w_shape[2:]

        self.serial = False
        if self.P_w.size == 1:
            self.serial = True
            self.conv_layer = self.TorchConvType(*args, **kwargs)
            return

        self.receives_weight = False
        self.stores_weight = False
        self.receives_bias = False
        self.stores_bias = False

        # Determine the weight and bias root subsets of P_w and initialize
        # the learnable parameters there
        if self.P_w.active:
            # All of P_w always receives the weight
            self.receives_weight = True

            # This subset is taken to be the origin of the spatial component;
            # a concrete illustration of these subsets follows this example.
            w_root_subset = []
            for i, c in enumerate(range_index(P_w.shape)):
                c = np.asarray(c)
                # Find the P_co x P_ci x 1 x ... x 1 subset to store the weights
                if np.all(c[2:] == 0):
                    w_root_subset.append(i)

            self.P_wr_base = self.P_w.create_partition_inclusive(w_root_subset)
            # ones are needed so the broadcast will work
            self.P_wr = self.P_wr_base.create_cartesian_topology_partition(
                [P_co, P_ci] + [1] * len(P_spatial))
            self.stores_weight = self.P_wr.active

            b_subset = []
            for i, c in enumerate(range_index(P_w.shape)):
                c = np.asarray(c)
                # Find the P_co x 1 x P_0 x ... x P_D-1 subset that needs biases in its calculation.
                # This is everywhere that the input channels is rank 0.
                if c[1] == 0:
                    b_subset.append(i)

            self.P_b_base = self.P_w.create_partition_inclusive(b_subset)
            self.P_b = self.P_b_base.create_cartesian_topology_partition(
                [P_co] + [1] + list(P_spatial))
            self.receives_bias = self.P_b.active and bias

            # Now find the subset of _that_ which actually stores the learnable parameter.
            b_root_subset = []
            for i, c in enumerate(range_index(P_w.shape)):
                c = np.asarray(c)
                # Find the P_co x 1 x 1 x ... x 1 subset to store the biases
                if np.all(c[1:] == 0):
                    b_root_subset.append(i)

            self.P_br_base = self.P_w.create_partition_inclusive(b_root_subset)
            # ones are needed so the broadcast will work
            self.P_br = self.P_br_base.create_cartesian_topology_partition(
                [P_co] + [1] + [1] * len(P_spatial))
            self.stores_bias = self.P_br.active and bias

            # Correct the input arguments based on local properties
            local_kwargs = {}
            local_kwargs.update(kwargs)

            # Do this before checking serial so that the layer works properly
            # in the serial case
            local_channels = compute_subshape(P_channels, P_w.index[0:2],
                                              [out_channels, in_channels])
            local_out_channels, local_in_channels = local_channels
            local_kwargs["in_channels"] = local_in_channels
            local_kwargs["out_channels"] = local_out_channels

            local_kwargs["bias"] = self.receives_bias
            self.conv_layer = self.TorchConvType(*args, **local_kwargs)

            # If we store the weight it is a learnable parameter iff it is
            # learnable by default in the layer, which it is.
            if self.stores_weight:
                self._weight = torch.nn.Parameter(
                    self.conv_layer.weight.detach())
            else:
                self._weight = zero_volume_tensor()
            # This always exists so we can copy the property
            self._weight.requires_grad = self.conv_layer.weight.requires_grad

            # https://discuss.pytorch.org/t/assign-parameters-to-nn-module-and-have-grad-fn-track-it/62677/2
            new_weight = self.conv_layer.weight.detach() * 0
            new_weight.requires_grad = self.conv_layer.weight.requires_grad
            del self.conv_layer.weight
            self.conv_layer.weight = new_weight

            # If we store the bias, it is a learnable parameter iff it is
            # learnable by default in the layer, which is only true if it
            # exists.
            if self.stores_bias:
                self._bias = torch.nn.Parameter(self.conv_layer.bias.detach())
            else:
                self._bias = zero_volume_tensor()
            # This does not always exist, but when it does we can copy the
            # property.
            if self.receives_bias:
                self._bias.requires_grad = self.conv_layer.bias.requires_grad

                # https://discuss.pytorch.org/t/assign-parameters-to-nn-module-and-have-grad-fn-track-it/62677/2
                new_bias = self.conv_layer.bias.detach() * 0
                new_bias.requires_grad = self.conv_layer.bias.requires_grad
                del self.conv_layer.bias
                self.conv_layer.bias = new_bias

        # Now we need to share the kernel structure.  The kernel structure
        # covers only the spatial dimensions.
        self.conv_kernel_size = None
        self.conv_stride = None
        self.conv_padding = None
        self.conv_dilation = None
        if P_union.rank == 0:
            self.conv_kernel_size = np.array(self.conv_layer.kernel_size,
                                             dtype=int)
            self.conv_stride = np.array(self.conv_layer.stride, dtype=int)
            self.conv_padding = np.array(self.conv_layer.padding, dtype=int)
            self.conv_dilation = np.array(self.conv_layer.dilation,
                                          dtype=int)
        self.conv_kernel_size = P_union.broadcast_data(self.conv_kernel_size,
                                                       root=0)
        self.conv_stride = P_union.broadcast_data(self.conv_stride, root=0)
        self.conv_padding = P_union.broadcast_data(self.conv_padding, root=0)
        self.conv_dilation = P_union.broadcast_data(self.conv_dilation, root=0)

        # We need the halo shape, and other info, to fully populate the pad,
        # halo exchange, and unpad layers.  For pad and unpad, we defer their
        # construction to the pre-forward hook.

        self.pad_layer = None
        self.unpad_layer = None

        # We need to be able to remove some data from the input to the conv
        # layer.
        self.needed_slices = None

        # For the halo layer we also defer construction, so that we can have
        # the halo shape for the input.  The halo will allocate its own
        # buffers, but it needs this information at construction to be able
        # to do this in the pre-forward hook.

        self.halo_layer = None

        # Variables for tracking input changes and buffer construction
        self._distdl_is_setup = False
        self._input_shape = None
        self._input_requires_grad = None

        if P_w.active:
            self.w_broadcast = Broadcast(self.P_wr,
                                         self.P_w,
                                         preserve_batch=False)

        if self.receives_bias or self.stores_bias:
            self.b_broadcast = Broadcast(self.P_br,
                                         self.P_b,
                                         preserve_batch=False)

        self.x_broadcast = Broadcast(self.P_x, self.P_w, preserve_batch=True)
        self.y_sum_reduce = SumReduce(self.P_w, self.P_y, preserve_batch=True)
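
To make the w_root_subset, b_subset, and b_root_subset selection above concrete, here is a small standalone illustration for a hypothetical P_w of shape 2 x 2 x 3, assuming range_index enumerates cartesian indices in row-major order:

import itertools

import numpy as np

# Hypothetical partition shape: P_co = 2, P_ci = 2, one spatial dimension of size 3.
P_w_shape = [2, 2, 3]

# Stand-in for iterating range_index(P_w_shape): every cartesian index, row-major.
w_root_subset, b_subset, b_root_subset = [], [], []
for i, c in enumerate(itertools.product(*[range(s) for s in P_w_shape])):
    c = np.asarray(c)
    if np.all(c[2:] == 0):   # P_co x P_ci x 1: stores the learnable weight
        w_root_subset.append(i)
    if c[1] == 0:            # P_co x 1 x P_0: needs a bias in its local conv
        b_subset.append(i)
    if np.all(c[1:] == 0):   # P_co x 1 x 1: stores the learnable bias
        b_root_subset.append(i)

print(w_root_subset)   # [0, 3, 6, 9]        one worker per (co, ci) block
print(b_subset)        # [0, 1, 2, 6, 7, 8]
print(b_root_subset)   # [0, 6]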
Example No. 21
layer = Repartition(P_x, P_y, preserve_batch=False)

# Setup the input tensor.  Any worker in P_x will generate its part of the
# input tensor.  Any worker not in P_x will have a zero-volume tensor.
#
# Input tensor will be (on a 1 x 1 partition):
# [ [ 1 1 1 1 1 ]
#   [ 1 1 1 1 1 ]
#   [ 1 1 1 1 1 ]
#   [ 1 1 1 1 1 ]
#   [ 1 1 1 1 1 ]
#   [ 1 1 1 1 1 ]
#   [ 1 1 1 1 1 ] ]
x = zero_volume_tensor()
if P_x.active:
    x_local_shape = slicing.compute_subshape(P_x.shape, P_x.index,
                                             x_global_shape)
    x = np.zeros(x_local_shape) + P_x.rank + 1
    x = torch.from_numpy(x)
x.requires_grad = True
print(f"rank {P_world.rank}; index {P_x.index}; value {x}")

# Apply the layer.
#
# Output tensor will be (on a 2 x 2 partition):
# [ [ 1 1 1 | 1 1 ]
#   [ 1 1 1 | 1 1 ]
#   [ 1 1 1 | 1 1 ]
#   [ 1 1 1 | 1 1 ]
#   -------------
#   [ 1 1 1 | 1 1 ]
#   [ 1 1 1 | 1 1 ]
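
The listing above is cut off at this point by the source page; a minimal sketch of the remaining step, using only names already defined in it (an assumption, not recovered source):

# Sketch only: apply the Repartition layer and print each worker's local piece.
y = layer(x)
print(f"rank {P_world.rank}; index {P_y.index}; value {y}")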
Example No. 22
    def backward(ctx, grad_output):

        P_union = ctx.P_union
        x_global_shape = ctx.x_global_shape

        P_x = ctx.P_x
        in_data = ctx.in_data
        in_buffers = ctx.in_buffers

        P_y = ctx.P_y
        out_data = ctx.out_data
        out_buffers = ctx.out_buffers

        preserve_batch = ctx.preserve_batch

        dtype = ctx.dtype

        input_requires_grad = ctx.input_requires_grad

        requests = []

        # Default everyone to output None
        if preserve_batch:
            grad_input = zero_volume_tensor(grad_output.shape[0])
        else:
            grad_input = zero_volume_tensor()

        # Recv my input parts
        recv_count = 0
        if P_x.active:
            for (sl, sz, partner), buff in zip(in_data, in_buffers):
                if buff is not None:
                    req = P_union.comm.Irecv(buff, source=partner, tag=113)
                    requests.append(req)
                else:
                    # We add this if there is no recv so that the indices of
                    # the requests array match the indices of in_data and
                    # in_buffers.
                    requests.append(MPI.REQUEST_NULL)
                recv_count += 1

        # Pack and send my input parts
        send_count = 0
        if P_y.active:
            grad_output_numpy = grad_output.detach().numpy()
            for (sl, sz, partner), buff in zip(out_data, out_buffers):
                if buff is not None:
                    np.copyto(buff, grad_output_numpy[tuple(sl)].ravel())
                    req = P_union.comm.Isend(buff, dest=partner, tag=113)
                    requests.append(req)
                else:
                    # We add this for symmetry, but don't really need it.
                    requests.append(MPI.REQUEST_NULL)
                send_count += 1

        if P_x.active:
            index = P_x.index
            x_local_shape = compute_subshape(P_x.shape, index, x_global_shape)
            # TODO(#25): The dtype should not be fixed, but correcting this is
            #            a thing that needs to be resolved globally.
            grad_input = np.zeros(x_local_shape, dtype=dtype)

        # Unpack the received data as it arrives
        completed_count = 0
        while (completed_count < len(requests)):
            status = MPI.Status()
            index = MPI.Request.Waitany(requests, status)

            # If the completed request is an instance of MPI.REQUEST_NULL,
            # Waitany does not return its index; it returns MPI.UNDEFINED
            # instead.
            if P_x.active and index < recv_count and index != MPI.UNDEFINED:
                # Unpack my output parts
                sl, sz, partner = in_data[index]
                buff = in_buffers[index]
                if buff is not None:
                    sh = grad_input[tuple(sl)].shape
                    # This would normally be an add into the grad_input tensor
                    # but we just created it, so a copy is sufficient.
                    np.copyto(grad_input[tuple(sl)], buff.reshape(sh))

            completed_count += 1

        if P_x.active:
            grad_input = torch.from_numpy(grad_input)
            grad_input.requires_grad = input_requires_grad

        return grad_input, None, None, None, None, None, None, None, None, None, None
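
The receive/send bookkeeping above follows a common mpi4py pattern: post all nonblocking requests, padding with MPI.REQUEST_NULL so request indices stay aligned with the data and buffer lists, then unpack each receive as Waitany reports it complete. A minimal standalone sketch of that pattern (a hypothetical two-rank exchange, independent of the class above):

import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.Get_rank()
partner = 1 - rank  # assumes exactly 2 ranks

# None marks a slot with nothing to receive or send.
recv_buffers = [np.empty(4), None]
send_buffers = [np.full(4, rank, dtype=np.float64), None]

requests = []
for buff in recv_buffers:
    requests.append(comm.Irecv(buff, source=partner, tag=113)
                    if buff is not None else MPI.REQUEST_NULL)
for buff in send_buffers:
    requests.append(comm.Isend(buff, dest=partner, tag=113)
                    if buff is not None else MPI.REQUEST_NULL)

n_recv = len(recv_buffers)
completed = 0
while completed < len(requests):
    index = MPI.Request.Waitany(requests)
    # Waitany returns MPI.UNDEFINED for null requests; only unpack receives.
    if index != MPI.UNDEFINED and index < n_recv:
        print(f"rank {rank} unpacked {recv_buffers[index]}")
    completed += 1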
Example No. 23
def test_transpose_adjoint(barrier_fence_fixture, comm_split_fixture,
                           P_x_ranks, P_x_shape, P_y_ranks, P_y_shape,
                           x_global_shape, balanced):

    import numpy as np
    import torch

    from distdl.backends.mpi.partition import MPIPartition
    from distdl.nn.transpose import DistributedTranspose
    from distdl.utilities.slicing import compute_subshape
    from distdl.utilities.torch import zero_volume_tensor

    # use_cuda is assumed to be defined at module scope in the original test file.
    device = torch.device('cuda' if use_cuda else 'cpu')

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return
    P_world = MPIPartition(base_comm)

    # Create the partitions
    P_x_base = P_world.create_partition_inclusive(P_x_ranks)
    P_x = P_x_base.create_cartesian_topology_partition(P_x_shape)

    P_y_base = P_world.create_partition_inclusive(P_y_ranks)
    P_y = P_y_base.create_cartesian_topology_partition(P_y_shape)

    # The global tensor size is the same for x and y
    layer = DistributedTranspose(P_x, P_y, preserve_batch=False)
    layer = layer.to(device)

    # Forward Input
    x = zero_volume_tensor(device=device)
    if P_x.active:
        if balanced:
            x_local_shape = compute_subshape(P_x.shape, P_x.index,
                                             x_global_shape)
        else:
            quotient = np.atleast_1d(x_global_shape) // np.atleast_1d(
                P_x_shape)
            remainder = np.atleast_1d(x_global_shape) % np.atleast_1d(
                P_x_shape)
            loc = np.where(P_x.index == 0)
            x_local_shape = quotient.copy()
            x_local_shape[loc] += remainder[loc]

        x = torch.randn(*x_local_shape, device=device)

    x.requires_grad = True

    # Adjoint Input
    dy = zero_volume_tensor(device=device)
    if P_y.active:
        y_local_shape = compute_subshape(P_y.shape, P_y.index, x_global_shape)
        dy = torch.randn(*y_local_shape, device=device)

    # y = F @ x
    y = layer(x)

    # dx = F* @ dy
    y.backward(dy)
    dx = x.grad

    x = x.detach()
    dx = dx.detach()
    dy = dy.detach()
    y = y.detach()

    check_adjoint_test_tight(P_world, x, dx, y, dy)

    P_world.deactivate()
    P_x_base.deactivate()
    P_x.deactivate()
    P_y_base.deactivate()
    P_y.deactivate()
Example No. 24
    dilation = [1, 1, 1, 1]

    exchange_info = mockup_conv_layer._compute_exchange_info(x_global_shape,
                                                             kernel_size,
                                                             stride,
                                                             padding,
                                                             dilation,
                                                             P_x.active,
                                                             P_x.shape,
                                                             P_x.index)
    halo_shape = exchange_info[0]
    recv_buffer_shape = exchange_info[1]
    send_buffer_shape = exchange_info[2]

    x_local_shape = compute_subshape(P_x.shape,
                                     P_x.index,
                                     x_global_shape)

    value = (1 + rank) * (10 ** rank)
    a = np.full(shape=x_local_shape, fill_value=value, dtype=float)

    forward_input_padnd_layer = PadNd(halo_shape.astype(int), value=0, partition=P_x)
    adjoint_input_padnd_layer = PadNd(halo_shape.astype(int), value=value, partition=P_x)
    t = torch.tensor(a, requires_grad=True)
    t_forward_input = forward_input_padnd_layer.forward(t)
    t_adjoint_input = adjoint_input_padnd_layer.forward(t)

    halo_layer = HaloExchange(P_x, halo_shape, recv_buffer_shape, send_buffer_shape)

    print_sequential(cart_comm, f'rank = {rank}, t_forward_input =\n{t_forward_input.int()}')
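
This listing also ends early; a minimal sketch of driving the exchange and its adjoint, using only names already defined above (an assumption, not recovered source):

# Sketch only: exchange halos in place, then drive the adjoint via backward.
t_forward_output = halo_layer(t_forward_input)
print_sequential(cart_comm, f'rank = {rank}, t_forward_output =\n{t_forward_output.int()}')

t_forward_output.backward(t_adjoint_input)
print_sequential(cart_comm, f'rank = {rank}, t.grad =\n{t.grad.int()}')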
Example No. 25
def test_halo_exchange_adjoint(barrier_fence_fixture,
                               comm_split_fixture,
                               P_x_ranks, P_x_shape,
                               x_global_shape,
                               kernel_size, stride, padding, dilation,
                               MockKernelStyle):
    import numpy as np
    import torch

    from distdl.backends.mpi.partition import MPIPartition
    from distdl.nn.halo_exchange import HaloExchange
    from distdl.nn.padnd import PadNd
    from distdl.utilities.slicing import compute_subshape
    from distdl.utilities.torch import zero_volume_tensor

    # Isolate the minimum needed ranks
    base_comm, active = comm_split_fixture
    if not active:
        return
    P_world = MPIPartition(base_comm)
    P_x_base = P_world.create_partition_inclusive(P_x_ranks)
    P_x = P_x_base.create_cartesian_topology_partition(P_x_shape)

    x_global_shape = np.asarray(x_global_shape)
    kernel_size = np.asarray(kernel_size)
    stride = np.asarray(stride)
    padding = np.asarray(padding)
    dilation = np.asarray(dilation)

    halo_shape = None
    recv_buffer_shape = None
    send_buffer_shape = None
    if P_x.active:
        mockup_layer = MockKernelStyle()
        exchange_info = mockup_layer._compute_exchange_info(x_global_shape,
                                                            kernel_size,
                                                            stride,
                                                            padding,
                                                            dilation,
                                                            P_x.active,
                                                            P_x.shape,
                                                            P_x.index)
        halo_shape = exchange_info[0]
        recv_buffer_shape = exchange_info[1]
        send_buffer_shape = exchange_info[2]

    pad_layer = PadNd(halo_shape, value=0)
    halo_layer = HaloExchange(P_x, halo_shape, recv_buffer_shape, send_buffer_shape)

    x = zero_volume_tensor(x_global_shape[0])
    if P_x.active:
        x_local_shape = compute_subshape(P_x.shape,
                                         P_x.index,
                                         x_global_shape)
        x = torch.tensor(np.random.randn(*x_local_shape))
        x = pad_layer.forward(x)
    x.requires_grad = True

    dy = zero_volume_tensor(x_global_shape[0])
    if P_x.active:
        dy = torch.tensor(np.random.randn(*x.shape))

    x_clone = x.clone()
    dy_clone = dy.clone()

    # x_clone is modified in place by halo_layer, but we assign y to
    # reference it for clarity
    y = halo_layer(x_clone)

    # dy_clone is modified in place by the adjoint of halo_layer, but we
    # assign dx to reference it for clarity
    y.backward(dy_clone)
    dx = dy_clone

    x = x.detach()
    dx = dx.detach()
    dy = dy.detach()
    y = y.detach()

    check_adjoint_test_tight(P_world, x, dx, y, dy)