Example #1
import torch

from torchgpipe.microbatch import Batch, gather  # assumed import path


def test_gather_tensors():
    a = torch.zeros(1, 1)
    b = torch.zeros(1, 1)

    ab = gather([Batch(a), Batch(b)])

    assert ab.size() == (2, 1)
Example #2
import torch

from torchgpipe.microbatch import gather  # assumed import path


def test_gather_tensors():
    a = torch.zeros(1, 1)
    b = torch.zeros(1, 1)

    ab = gather([a, b], device=torch.device('cpu'))

    assert ab.size() == (2, 1)
Example #3
import torch

from torchgpipe.microbatch import gather  # assumed import path


def test_gather_tuples():
    a = (torch.zeros(1, 1), torch.zeros(2, 2))
    b = (torch.zeros(1, 1), torch.zeros(2, 2))

    ab = gather([a, b], device=torch.device('cpu'))

    assert isinstance(ab, tuple)
    assert ab[0].size() == (2, 1)
    assert ab[1].size() == (4, 2)
Example #4
import torch

from torchgpipe.microbatch import Batch, gather  # assumed import path


def test_gather_tuples():
    a = (torch.zeros(1, 1), torch.zeros(2, 2))
    b = (torch.zeros(1, 1), torch.zeros(2, 2))

    ab = gather([Batch(a), Batch(b)])

    assert isinstance(ab, tuple)
    assert ab[0].size() == (2, 1)
    assert ab[1].size() == (4, 2)
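
Together, these four tests pin down gather's contract: tensors are concatenated along dimension 0, and tuples are concatenated element-wise, preserving the tuple structure. Below is a minimal sketch consistent with the tests; gather_sketch is a hypothetical name, and it deliberately ignores the Batch wrapper and device handling that the real torchgpipe implementation provides.

from typing import Sequence, Tuple, Union

import torch
from torch import Tensor

TensorOrTensors = Union[Tensor, Tuple[Tensor, ...]]


def gather_sketch(inputs: Sequence[TensorOrTensors]) -> TensorOrTensors:
    if isinstance(inputs[0], tuple):
        # Tuple case: concatenate the i-th element of every micro-batch.
        return tuple(torch.cat([inp[i] for inp in inputs])
                     for i in range(len(inputs[0])))
    # Tensor case: a single concatenation along the batch dimension.
    return torch.cat(list(inputs))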
Example #5
import torch

from torchgpipe.microbatch import gather, scatter  # assumed import path


def test_default_device_index():
    default_cuda = torch.device('cuda')
    assert default_cuda.index is None

    x = torch.rand(2, 1)
    a, b = scatter(x, chunks=2, device=default_cuda)
    y = gather([a, b], device=default_cuda)

    assert a.is_cuda
    assert b.is_cuda
    assert y.is_cuda
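
This test leans on a PyTorch guarantee: torch.device('cuda') carries no device index (its .index is None) and is only resolved to a concrete index when a tensor is actually placed, so scatter and gather must accept the bare device. A standalone illustration of that behaviour, assuming a CUDA-enabled build:

import torch

if torch.cuda.is_available():
    default_cuda = torch.device('cuda')   # no index attached yet
    assert default_cuda.index is None

    # Placement resolves the bare 'cuda' device to the current device.
    x = torch.rand(2, 1, device=default_cuda)
    assert x.device.index == torch.cuda.current_device()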
Example #6
    def forward(self,
                input: TensorOrTensors) -> TensorOrTensors:  # type: ignore
        """:class:`GPipe` is a fairly transparent module wrapper. It doesn't
        modify the input and output signature of the underlying module. But
        there's type restriction. Input and output have to be a
        :class:`~torch.Tensor` or a tuple of tensors. This restriction is
        applied at partition boundaries too.

        Args:
            input (torch.Tensor or tensors): input mini-batch

        Returns:
            tensor or tensors: output mini-batch

        Raises:
            TypeError: input is not a tensor or tensors.

        """
        microbatch.check(input)

        if not self.devices:
            # An empty sequential module is legal; just pass the input through.
            return input

        # Divide a mini-batch into micro-batches.
        batches = microbatch.scatter(input, self.chunks)

        # Separate CUDA streams for copy.
        copy_streams = self._ensure_copy_streams()

        # The micro-batch index where the checkpointing stops.
        if self.training:
            checkpoint_stop = {
                'always': self.chunks,
                'except_last': self.chunks - 1,
                'never': 0,
            }[self.checkpoint]
        else:
            checkpoint_stop = 0

        # Run pipeline parallelism.
        pipeline = Pipeline(batches, self.partitions, self.devices,
                            copy_streams, self._skip_layout, checkpoint_stop)
        pipeline.run()

        # Merge the micro-batches into one mini-batch.
        output = microbatch.gather(batches)
        return output
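
Because forward preserves the wrapped module's signature, a GPipe model is called like any other nn.Module. Here is a minimal usage sketch, assuming torchgpipe's public constructor; the two-layer model and its sizes are hypothetical.

import torch
from torch import nn
from torchgpipe import GPipe

# Each child of the nn.Sequential is assigned to a partition via `balance`.
model = nn.Sequential(nn.Linear(16, 32), nn.Linear(32, 8))
model = GPipe(model, balance=[1, 1], devices=['cpu', 'cpu'], chunks=4)

x = torch.rand(64, 16)   # one mini-batch, split internally into 4 micro-batches
y = model(x)             # same call signature as the raw module
assert y.size() == (64, 8)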
Example #7
    def _pull_output(
        self,
        num_inputs: int,
        in_queue: PriorityQueue,
        out_queue: PriorityQueue,
    ) -> Tensor:
        """Collects and concatenates chunked outputs from the last partition.

        If an exception from a partition is detected, all workers are closed
        and the exception is re-raised.

        Raises:
            Exception: any exception from a partition

        """
        # All worker threads will be closed when receiving this message.
        close = Message(-1, None)

        outputs = []
        for _ in range(num_inputs):
            msg = out_queue.get()
            out_queue.task_done()

            if msg.i == -1:
                # Close worker threads immediately.
                in_queue.put(close)
                out_queue.get()

                # Raise the exception from a partition.
                exc_info = msg.payload
                raise exc_info[0].with_traceback(exc_info[1], exc_info[2])

            output, _ = msg.payload
            outputs.append(output)

        # Indicate the end of micro-batches.
        in_queue.put(close)

        # Merge the micro-batches into one mini-batch.
        out_device = self.devices[-1]
        output = gather(outputs, device=out_device)

        # Wait until the last worker thread closes.
        out_queue.get()

        return output
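
_pull_output is the consuming half of a sentinel protocol: a Message with index -1 signals a failed partition when read from out_queue, and a shutdown request when written to in_queue. The sketch below shows the producing half it pairs with; this Message namedtuple and worker loop are illustrative, not torchgpipe's actual worker code.

from queue import PriorityQueue
from typing import Any, NamedTuple


class Message(NamedTuple):
    # PriorityQueue orders by `i`; -1 sorts first, so the close/exception
    # sentinel is drained ahead of any pending micro-batch results.
    i: int
    payload: Any


def worker(in_queue: PriorityQueue, out_queue: PriorityQueue) -> None:
    while True:
        msg = in_queue.get()
        if msg.i == -1:
            # Acknowledge the close sentinel; this is what the final
            # out_queue.get() in _pull_output is waiting for.
            out_queue.put(Message(-1, None))
            break
        # Stand-in for running the partition on one micro-batch.
        out_queue.put(Message(msg.i, msg.payload))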
Example #8
    def _pull_output(
        self,
        num_inputs: int,
        in_queue: PriorityQueue,
        out_queue: PriorityQueue,
    ) -> Tensor:
        """Collects and concatenates chunked outputs from the last partition.

        If an exception from a partition is detected, all workers are closed
        and the exception is re-raised.

        Raises:
            Exception: any exception from a partition

        """
        outputs = []
        for _ in range(num_inputs):
            msg = out_queue.get()
            out_queue.task_done()

            if msg.i == -1:
                # Close worker threads immediately.
                close = Message(-1, None)
                in_queue.put(close)
                out_queue.get()

                # Raise the exception from a partition.
                exc_info = msg.payload
                raise exc_info[0].with_traceback(exc_info[1], exc_info[2])

            output, _, _ = msg.payload
            outputs.append(output)

        # Merge the micro-batches into one mini-batch.
        out_device = self.devices[-1]
        output = gather(outputs, device=out_device)

        # Wait until the last worker thread closes.
        out_queue.get()

        return output