Example #1
    def _test_wait_stream(self, source, target, cuda_sleep=None):
        with use_stream(target):
            if is_cuda(target):
                cuda_sleep(0.5)
            x = torch.ones(100, 100, device=get_device(target))

        wait_stream(source, target)

        with use_stream(source):
            assert x.sum().item() == 10000
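The helpers used here (use_stream, get_device, is_cuda, wait_stream) come from the pipeline's stream abstraction, which treats the CPU as a single implicit pseudo-stream so the same code path can drive CPU and CUDA devices. A minimal sketch of what they could look like, modeled on torch.distributed.pipeline.sync.stream (treat the details as assumptions, not the exact library code):

from contextlib import contextmanager
from typing import Generator, Union, cast

import torch


class CPUStreamType:
    """The CPU behaves like one implicit stream; this is its placeholder type."""


CPUStream = CPUStreamType()
AbstractStream = Union[torch.cuda.Stream, CPUStreamType]


def is_cuda(stream: AbstractStream) -> bool:
    return stream is not CPUStream


def as_cuda(stream: AbstractStream) -> torch.cuda.Stream:
    return cast(torch.cuda.Stream, stream)


def get_device(stream: AbstractStream) -> torch.device:
    return as_cuda(stream).device if is_cuda(stream) else torch.device("cpu")


@contextmanager
def use_stream(stream: AbstractStream) -> Generator[None, None, None]:
    # CPU work needs no stream context; CUDA work is queued on `stream`.
    if not is_cuda(stream):
        yield
        return
    with torch.cuda.stream(as_cuda(stream)):
        yield


def wait_stream(source: AbstractStream, target: AbstractStream) -> None:
    # Block `source` until everything already queued on `target` has finished.
    if is_cuda(target):
        if is_cuda(source):
            as_cuda(source).wait_stream(as_cuda(target))
        else:
            as_cuda(target).synchronize()

In the test above, wait_stream(source, target) guarantees that the cuda_sleep and the ones-fill queued on target have completed before source reads x, which is why the sum assertion holds even though the fill was deliberately delayed.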
Example #2
def _test_copy_wait(prev_stream, next_stream, cuda_sleep=None):
    device = get_device(prev_stream)

    with use_stream(prev_stream):
        if is_cuda(prev_stream):
            cuda_sleep(0.5)
        x = torch.ones(100, device=device, requires_grad=True)

    (y,) = Copy.apply(prev_stream, next_stream, x)
    (y,) = Wait.apply(prev_stream, next_stream, y)

    with use_stream(next_stream):
        assert torch.allclose(y.sum(), torch.tensor(100.0, device=device))
        y.norm().backward()
    with use_stream(prev_stream):
        assert torch.allclose(x.grad.sum(), torch.tensor(10.0, device=device))
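Copy and Wait are torch.autograd.Function subclasses, so the cross-stream handoff participates in the backward pass: Copy moves tensors from prev_stream's device to next_stream's, while Wait only inserts a synchronization point and passes the data through. A simplified sketch of Wait under those assumptions (Copy is similar but also moves the data and does record-stream bookkeeping):

class Wait(torch.autograd.Function):
    """Makes next_stream wait for prev_stream in forward, and the reverse
    in backward, passing the tensors through untouched."""

    @staticmethod
    def forward(ctx, prev_stream, next_stream, *input):
        ctx.prev_stream = prev_stream
        ctx.next_stream = next_stream
        wait_stream(next_stream, prev_stream)  # forward: next waits for prev
        return tuple(x.detach() for x in input)

    @staticmethod
    def backward(ctx, *grad_input):
        wait_stream(ctx.prev_stream, ctx.next_stream)  # backward: prev waits for next
        # No gradients for the two stream arguments.
        return (None, None) + grad_input

The expected gradient in the test follows from y.norm(): d||y||/dy = y/||y||, and with y all ones, ||y|| = 10, so each of the 100 gradient entries is 0.1 and their sum is 10.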
Example #3
    def run_pipeline(self, pipeline_record_rref: rpc.RRef) -> Optional[Tensor]:
        """Processes a min-batch on this partition.
           If this is the last partition (pipeline_record has no consumer), concatenates results of processing
           all chunks and returns the result as the output of the model on the whole mini-batch.
        """
        pipeline_record = pipeline_record_rref.local_value()
        self.run(pipeline_record)

        if not pipeline_record.consumers:
            result = microbatch.gather(pipeline_record.batches)
            assert len(result) == 1
            result = result[0]
            s0 = current_stream(result.device)
            if is_cuda(s0):
                # TODO. Investigate why this is needed and remove it if possible.
                as_cuda(s0).synchronize()
            return result

        return None
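run_pipeline leans on microbatch.gather to reassemble the per-chunk outputs into full mini-batch tensors; the assert len(result) == 1 then holds because this model has a single output. A rough sketch of what gather might do, assuming each Batch exposes its tensors as a tuple chunked along dim 0 (names and layout here are assumptions):

def gather(batches):
    # Transpose [chunk][output] -> [output][chunk], then concatenate each
    # output's chunks back into one mini-batch-sized tensor.
    per_output = zip(*(batch.tensors for batch in batches))
    return [torch.cat(chunks) for chunks in per_output]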
Example #4
    def compute(self, pipeline_record: DistributedPipelineRecord,
                chunk: int) -> None:
        """Runs tasks with synchronization to TensorPipe streams."""
        checkpoint_stop = self.checkpoint_stop

        # Disable checkpointing if in eval mode.
        if not self.module.training:
            checkpoint_stop = 0

        exc_info: Optional[ExcInfo] = None

        batch = pipeline_record.get_batch(chunk)

        if is_cuda(self.stream):
            pipeline_record.sync_stream(chunk, as_cuda(self.stream))

        # Determine whether to checkpoint this chunk.
        checkpoint = chunk < checkpoint_stop
        if checkpoint:

            def function(input: TensorOrTensors,
                         chunk_id: int = chunk) -> TensorOrTensors:
                with record_function("chunk%d-rank%d" %
                                     (chunk_id, pipeline_record.rank)):
                    result = self.module(*input)
                    if self.num_outputs is None:
                        result = (result, )
                    return tuple(result)

            chk = Checkpointing(function, batch)
            task = Task(self.stream,
                        compute=chk.checkpoint,
                        finalize=chk.recompute)
            del function, chk

        else:

            def compute(
                batch: Batch = batch,
                chunk_id: int = chunk,
                rank: int = pipeline_record.rank
                if pipeline_record is not None else -1,
            ) -> Batch:
                # Use the captured `rank` so the closure does not dereference a
                # possibly-None pipeline_record.
                with record_function("chunk%d-rank%d" % (chunk_id, rank)):
                    result = self.module(*batch.tensors)
                    if self.num_outputs is None:
                        result = (result, )
                return Batch(result, chunk_id)

            task = Task(self.stream, compute=compute, finalize=None)
            del compute

        self.in_queue.put(task)

        ok, payload = self.out_queue.get()

        # Hold the exception, if any; it is re-raised below.
        if not ok:
            exc_info = cast(ExcInfo, payload)
        else:
            task, batch = cast(Tuple[Task, Batch], payload)

            with use_device(self.device):
                task.finalize(batch)

            pipeline_record.batches[chunk] = batch

        if exc_info is not None:
            raise exc_info[0].with_traceback(exc_info[1], exc_info[2])
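compute hands the Task to a worker thread through in_queue and then blocks on out_queue for an (ok, payload) pair. The other side of those queues follows the usual pipeline worker pattern; a hedged sketch of such a loop (simplified, with details assumed rather than taken from the source):

import sys
from queue import Queue


def worker(in_queue: Queue, out_queue: Queue) -> None:
    """Runs each task's compute() and reports either the produced batch or
    the captured exception info back through out_queue."""
    while True:
        task = in_queue.get()
        if task is None:  # sentinel: shut the worker down
            break
        try:
            batch = task.compute()
        except Exception:
            # (False, exc_info) makes compute() re-raise on the caller side.
            out_queue.put((False, sys.exc_info()))
            continue
        # (True, (task, batch)) lets the caller run task.finalize(batch).
        out_queue.put((True, (task, batch)))

This matches the branches above: ok selects between the (task, batch) payload, whose finalize runs the checkpoint recomputation when checkpointing was enabled, and the exc_info payload that is re-raised at the end.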