Example #1
def test_scatter_tensor():
    ab = torch.zeros(2, 1)

    a, b = scatter(ab, chunks=2, device=torch.device('cpu'))

    assert a.size() == (1, 1)
    assert b.size() == (1, 1)
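The assertions above only pin down the observable behavior: a (2, 1) mini-batch splits into two (1, 1) micro-batches along dim 0. As a rough mental model (a hypothetical stand-in, not the library's actual implementation), scatter over a plain tensor could look like this:

import torch

def scatter_sketch(tensor, chunks, device=None):
    # Hypothetical scatter(): split a mini-batch along dim 0 into
    # micro-batches, optionally moving each one to a device.
    micro_batches = tensor.chunk(chunks, dim=0)
    if device is not None:
        micro_batches = tuple(t.to(device) for t in micro_batches)
    return micro_batches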
Example #2
    def push_input(self,
                   input: TensorOrTensors,
                   in_queue: PriorityQueue,
                   ) -> int:
        """Pushes chunked inputs to the first partition."""
        # Divide a mini-batch into micro-batches.
        inputs = scatter(input, chunks=self.chunks, device=self.in_device)

        # The number of inputs might be smaller than the number of chunks.
        num_inputs = len(inputs)

        for i, _input in enumerate(inputs):
            # NOTE(sublee): 'except_last' is the default option. Compare it first.
            if self.checkpoint == 'except_last':
                checkpoint = (i < self.chunks - 1)
            elif self.checkpoint == 'always':
                checkpoint = True
            elif self.checkpoint == 'never':
                checkpoint = False

            msg = Message(i, (_input, checkpoint))
            in_queue.put(msg)

        close = Message(num_inputs, None)
        in_queue.put(close)

        return num_inputs
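The trailing Message with a None payload acts as an end-of-stream sentinel: a consumer knows all num_inputs micro-batches have arrived once it sees it. Assuming Message behaves like an (index, payload) pair, as it is constructed above, a hypothetical consumer loop on the partition side might look like:

def pull_inputs(in_queue):
    # Hypothetical consumer sketch: drain micro-batches until the
    # close sentinel (payload None) arrives.
    while True:
        _, payload = in_queue.get()
        if payload is None:
            break  # close message: every micro-batch has been pushed
        _input, checkpoint = payload
        # ... process the micro-batch here ...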
Example #3
    def _push_input(
        self,
        input: TensorOrTensors,
        in_queue: PriorityQueue,
    ) -> int:
        """Pushes chunked inputs to the first partition."""
        # Divide a mini-batch into micro-batches.
        in_device = self.devices[0]
        inputs = scatter(input, chunks=self.chunks, device=in_device)

        # The number of inputs might be smaller than the number of chunks.
        num_inputs = len(inputs)

        for i, _input in enumerate(inputs):
            # NOTE(sublee): 'except_last' is the default option. Compare it first.
            if self.checkpoint == 'except_last':
                checkpoint = (i < self.chunks - 1)
            elif self.checkpoint == 'always':
                checkpoint = True
            elif self.checkpoint == 'never':
                checkpoint = False

            # Every partition should track the current micro-batch. A
            # micro-batch lane can be identified by its detached leaf tensor.
            leaf = (_input[0]
                    if isinstance(_input, tuple) else _input).detach()

            msg = Message(i, (_input, leaf, checkpoint))
            in_queue.put(msg)

        close = Message(num_inputs, None)
        in_queue.put(close)

        return num_inputs
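The detached leaf works as a stable identity for the lane: detach() returns a new tensor that shares storage with the original but is a leaf outside the autograd graph. A quick standalone check of those properties:

import torch

x = torch.zeros(2, 1, requires_grad=True)
y = x * 2          # non-leaf: produced by an autograd operation
leaf = y.detach()  # a new leaf tensor, cut off from the graph
assert not y.is_leaf
assert leaf.is_leaf and not leaf.requires_grad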
Example #4
def test_scatter_tensor():
    ab = torch.zeros(2, 1)

    a, b = scatter(ab, chunks=2)

    assert a.tensor.size() == (1, 1)
    assert b.tensor.size() == (1, 1)
Example #5
def test_scatter_tuple():
    ab = (torch.zeros(2, 1), torch.zeros(4, 2))

    a, b = scatter(ab, chunks=2)

    assert a.tensors[0].size() == (1, 1)
    assert b.tensors[0].size() == (1, 1)
    assert a.tensors[1].size() == (2, 2)
    assert b.tensors[1].size() == (2, 2)
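Unlike Examples #1 and #6, these assertions go through .tensor / .tensors accessors, which suggests this version of scatter wraps each micro-batch in a small container instead of returning raw tensors or tuples. A hypothetical minimal wrapper consistent with the assertions (the class name and layout are assumptions):

from typing import Tuple
import torch

class Batch:
    # Hypothetical micro-batch container matching the accessors above.
    def __init__(self, value):
        self.value = value

    @property
    def tensor(self) -> torch.Tensor:
        assert isinstance(self.value, torch.Tensor)
        return self.value

    @property
    def tensors(self) -> Tuple[torch.Tensor, ...]:
        assert isinstance(self.value, tuple)
        return self.value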
Example #6
def test_scatter_tuple():
    ab = (torch.zeros(2, 1), torch.zeros(4, 2))

    a, b = scatter(ab, chunks=2, device=torch.device('cpu'))

    assert isinstance(a, tuple)
    assert isinstance(b, tuple)
    assert a[0].size() == (1, 1)
    assert b[0].size() == (1, 1)
    assert a[1].size() == (2, 2)
    assert b[1].size() == (2, 2)
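For the tuple case each element is split independently along dim 0, which is why the (2, 1) and (4, 2) tensors both divide into two micro-batches of half the rows. A rough stand-in for that behavior (again hypothetical, not the real implementation):

def scatter_tuple_sketch(tensors, chunks):
    # Hypothetical: chunk every tensor along dim 0, then regroup the
    # pieces so that each micro-batch is one tuple of tensors.
    chunked = [t.chunk(chunks, dim=0) for t in tensors]
    return tuple(zip(*chunked))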
Example #7
def test_default_device_index():
    default_cuda = torch.device('cuda')
    assert default_cuda.index is None

    x = torch.rand(2, 1)
    a, b = scatter(x, chunks=2, device=default_cuda)
    y = gather([a, b], device=default_cuda)

    assert a.is_cuda
    assert b.is_cuda
    assert y.is_cuda
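Two details here: torch.device('cuda') without an index refers to the current CUDA device (hence .index is None), and gather is the inverse of scatter, merging micro-batch outputs back into one mini-batch. A minimal hypothetical stand-in for gather:

import torch

def gather_sketch(tensors, device=None):
    # Hypothetical inverse of scatter: move each micro-batch to the
    # target device, then concatenate along the batch dimension.
    if device is not None:
        tensors = [t.to(device) for t in tensors]
    return torch.cat(tensors, dim=0)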
Example #8
    def forward(self,
                input: TensorOrTensors) -> TensorOrTensors:  # type: ignore
        """:class:`GPipe` is a fairly transparent module wrapper. It doesn't
        modify the input and output signature of the underlying module. But
        there's type restriction. Input and output have to be a
        :class:`~torch.Tensor` or a tuple of tensors. This restriction is
        applied at partition boundaries too.

        Args:
            input (torch.Tensor or tensors): input mini-batch

        Returns:
            tensor or tensors: output mini-batch

        Raises:
            TypeError: input is not a tensor or tensors.

        """
        microbatch.check(input)

        if not self.devices:
            # An empty sequential module is not illegal; just pass the input through.
            return input

        # Divide a mini-batch into micro-batches.
        batches = microbatch.scatter(input, self.chunks)

        # Separate CUDA streams for copy.
        copy_streams = self._ensure_copy_streams()

        # The micro-batch index where the checkpointing stops.
        if self.training:
            checkpoint_stop = {
                'always': self.chunks,
                'except_last': self.chunks - 1,
                'never': 0,
            }[self.checkpoint]
        else:
            checkpoint_stop = 0

        # Run pipeline parallelism.
        pipeline = Pipeline(batches, self.partitions, self.devices,
                            copy_streams, self._skip_layout, checkpoint_stop)
        pipeline.run()

        # Merge the micro-batches into one mini-batch.
        output = microbatch.gather(batches)
        return output
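For context, a typical call into this forward path might look like the sketch below, assuming a torchgpipe-style constructor with balance, devices, and chunks arguments (exact names may vary across versions):

import torch
from torch import nn
from torchgpipe import GPipe  # assumed import path

# Partition a 3-layer net into two groups of [2, 1] layers.
model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 4))
model = GPipe(model, balance=[2, 1], devices=['cpu', 'cpu'], chunks=4)

x = torch.rand(16, 8)  # mini-batch of 16 samples
y = model(x)           # internally scattered into 4 micro-batches
assert y.size() == (16, 4)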