Example #1
    def forward(self, input: TensorOrTensors) -> TensorOrTensors:
        """:class:`GPipe` is a fairly transparent module wrapper. It doesn't
        modify the input and output signature of the underlying module. But
        there's type restriction. Input and output have to be a
        :class:`~torch.Tensor` or a tuple of tensors. This restriction is
        applied at partition boundaries too.

        Args:
            input (tensor or tensors): input mini-batch

        Returns:
            tensor or tensors: output mini-batch

        Raises:
            TypeError: the input is not a tensor or a tuple of tensors.

        """
        if not self.devices:
            # The wrapped sequential module is empty, which is not illegal.
            # Just check the input type and pass the input through.
            check(input)
            return input

        in_queue, out_queue = self._spawn_workers()
        num_inputs = self._push_input(input, in_queue)
        return self._pull_output(num_inputs, in_queue, out_queue)
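
For context, a minimal usage sketch of how this forward pass is reached: the model is an
nn.Sequential wrapped by GPipe, and calling the wrapper with a tensor (or a tuple of tensors)
dispatches to forward() above. The balance and chunks values below are illustrative, and the
sketch assumes the torchgpipe package with at least two CUDA devices available.

import torch
from torch import nn
from torchgpipe import GPipe

# Illustrative model and partitioning: balance=[2, 1] places two layers on the
# first device and one on the second; chunks=4 splits each mini-batch into four
# micro-batches.
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 8))
model = GPipe(model, balance=[2, 1], chunks=4)

x = torch.rand(64, 16).to(model.devices[0])  # input must be a Tensor or a tuple of Tensors
y = model(x)                                 # runs the pipelined forward pass shown above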
Example #2
# Imports needed to run this test on its own; in the upstream repository
# ``check`` comes from the micro-batch utilities (torchgpipe.microbatch).
import pytest
import torch

from torchgpipe.microbatch import check


def test_check():
    check(torch.tensor(42))
    check((torch.tensor(4), torch.tensor(2)))

    with pytest.raises(TypeError):
        check(42)

    with pytest.raises(TypeError):
        check('str')

    with pytest.raises(TypeError):
        check((torch.tensor(4), 2))
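
The behaviour this test pins down is simple: check accepts a single tensor or a tuple whose
elements are all tensors, and raises TypeError for anything else. A minimal sketch of such a
validator (not necessarily the library's exact implementation; check_sketch is a made-up name)
could look like this:

from typing import Tuple, Union

from torch import Tensor

TensorOrTensors = Union[Tensor, Tuple[Tensor, ...]]


def check_sketch(input: TensorOrTensors) -> None:
    """Raise TypeError unless ``input`` is a Tensor or a tuple of Tensors."""
    if isinstance(input, tuple):
        for x in input:
            if not isinstance(x, Tensor):
                raise TypeError('expected Tensor, but got %s' % x.__class__.__name__)
        return

    if not isinstance(input, Tensor):
        raise TypeError('expected Tensor, but got %s' % input.__class__.__name__)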
Example #3
    def forward(self,
                input: TensorOrTensors) -> TensorOrTensors:  # type: ignore
        """:class:`GPipe` is a fairly transparent module wrapper. It doesn't
        modify the input and output signature of the underlying module. But
        there's type restriction. Input and output have to be a
        :class:`~torch.Tensor` or a tuple of tensors. This restriction is
        applied at partition boundaries too.

        Args:
            input (torch.Tensor or tensors): input mini-batch

        Returns:
            tensor or tensors: output mini-batch

        Raises:
            TypeError: the input is not a tensor or a tuple of tensors.

        """
        microbatch.check(input)

        if not self.devices:
            # An empty sequential module is not illegal; just pass the input through.
            return input

        # Divide a mini-batch into micro-batches.
        batches = microbatch.scatter(input, self.chunks)

        # Separate CUDA streams for copy.
        copy_streams = self._ensure_copy_streams()

        # The micro-batch index where the checkpointing stops.
        if self.training:
            checkpoint_stop = {
                'always': self.chunks,
                'except_last': self.chunks - 1,
                'never': 0,
            }[self.checkpoint]
        else:
            checkpoint_stop = 0

        # Run pipeline parallelism.
        pipeline = Pipeline(batches, self.partitions, self.devices,
                            copy_streams, self._skip_layout, checkpoint_stop)
        pipeline.run()

        # Merge the micro-batches into one mini-batch.
        output = microbatch.gather(batches)
        return output
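
The scatter and gather calls around the pipeline are the micro-batching step: the mini-batch is
split along the batch dimension into self.chunks micro-batches, the pipeline overwrites each
micro-batch with its partition outputs, and gather merges them back into one mini-batch. A rough,
self-contained sketch of that round trip (not the library's actual implementation, which also
handles tuples of tensors) could look like this:

import torch


def scatter_sketch(input: torch.Tensor, chunks: int):
    # Split a mini-batch into `chunks` micro-batches along the batch dimension.
    return list(input.chunk(chunks, dim=0))


def gather_sketch(outputs):
    # Merge the micro-batch outputs back into a single mini-batch.
    return torch.cat(outputs, dim=0)


x = torch.rand(8, 4)
micro = scatter_sketch(x, chunks=4)   # four micro-batches of shape (2, 4)
y = gather_sketch(micro)              # back to shape (8, 4)
assert torch.equal(x, y)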