def make_checkpoint(function: Function, input: TensorOrTensors, index: int) -> TensorOrTensors:
    """Makes a checkpoint with a simple interface like
    :func:`torch.utils.checkpoint.checkpoint`. It's only used to test or debug
    :class:`Checkpoint` and :class:`Recompute` without boilerplate.
    """
    batch = Batch(input, index)

    chk = Checkpointing(function, batch)
    batch = chk.checkpoint()
    chk.recompute(batch)

    return batch.tensor_or_tensors
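
# A minimal usage sketch (an illustration, not from the source): make_checkpoint
# mirrors torch.utils.checkpoint.checkpoint, so gradients flow through the
# recomputed forward pass. The doubling function, the shape, and the helper
# name _demo_make_checkpoint are hypothetical.
def _demo_make_checkpoint() -> None:
    a = torch.rand(4, requires_grad=True)
    out = make_checkpoint(lambda t: t * 2, a, 0)  # index 0: first micro-batch
    out.sum().backward()  # backward triggers the recomputation
    assert torch.allclose(a.grad, torch.full((4,), 2.0))  # d(sum(2t))/dt == 2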
def test_not_requires_grad():
    # Batch takes (value, index), as in make_checkpoint above.
    x = Batch(torch.rand(1, requires_grad=False), 0)
    assert not x[0].requires_grad

    def f(x):
        return x * 2

    chk = Checkpointing(f, x)
    x = chk.checkpoint()
    assert x[0].requires_grad

    chk.recompute(x)
    assert x[0].requires_grad

    x.tensor.backward()
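
# Hedged companion check (hypothetical, not from the source): the checkpointed
# backward should produce the same gradient as running the function directly.
# Uses only make_checkpoint from above; the helper name and shapes are
# illustrative.
def _check_checkpoint_gradient_matches() -> None:
    x = torch.rand(1, requires_grad=True)

    # Plain forward/backward as the reference.
    (x * 2).sum().backward()
    expected = x.grad.clone()

    # Checkpointed forward/backward should agree.
    x.grad = None
    make_checkpoint(lambda t: t * 2, x, 0).sum().backward()
    assert torch.allclose(x.grad, expected)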
def create_task_without_skip_trackers(
    checkpoint_stop: int, i: int, j: int, batch: Batch, partition: nn.Sequential,
) -> Task:
    # Determine whether to checkpoint this chunk or not.
    if i < checkpoint_stop:

        def function(
            input: TensorOrTensors,
            partition: nn.Sequential = partition,
            chunk_id: int = i,
            part_id: int = j,
        ) -> TensorOrTensors:
            with record_function("chunk%d-part%d" % (chunk_id, part_id)):
                return partition(input)

        chk = Checkpointing(function, batch)
        task = Task(None, compute=chk.checkpoint, finalize=chk.recompute)
        del function, chk

    else:

        def compute(
            batch: Batch = batch,
            partition: nn.Sequential = partition,
            chunk_id: int = i,
            part_id: int = j,
        ) -> Batch:
            with record_function("chunk%d-part%d" % (chunk_id, part_id)):
                return batch.call(partition)

        task = Task(None, compute=compute, finalize=None)
        del compute

    return task
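
# Hedged usage sketch (hypothetical): drive one task end to end on the CPU.
# Task.compute runs the (possibly checkpointed) forward, and Task.finalize
# schedules the recomputation needed for backward, mirroring how `compute`
# below consumes tasks. The partition, shapes, and helper name are assumptions.
def _demo_create_task() -> None:
    partition = nn.Sequential(nn.Linear(8, 8), nn.ReLU())
    batch = Batch(torch.rand(2, 8, requires_grad=True), 0)

    # checkpoint_stop=1 sends chunk i=0 down the checkpointing branch.
    task = create_task_without_skip_trackers(
        checkpoint_stop=1, i=0, j=0, batch=batch, partition=partition,
    )
    out = task.compute()  # checkpointed forward
    task.finalize(out)    # registers the recompute hook for backward
    out.tensor.sum().backward()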
def compute(self, pipeline_record: DistributedPipelineRecord, chunk: int) -> None:
    """Runs tasks with synchronization to tensor-pipe streams."""
    checkpoint_stop = self.checkpoint_stop

    # Disable checkpointing if in eval mode.
    if not self.module.training:
        checkpoint_stop = 0

    exc_info: Optional[ExcInfo] = None

    batch = pipeline_record.get_batch(chunk)

    if is_cuda(self.stream):
        pipeline_record.sync_stream(chunk, as_cuda(self.stream))

    # Determine whether to checkpoint this chunk or not.
    checkpoint = chunk < checkpoint_stop
    if checkpoint:

        def function(input: TensorOrTensors, chunk_id: int = chunk) -> TensorOrTensors:
            with record_function("chunk%d-rank%d" % (chunk_id, pipeline_record.rank)):
                result = self.module(*input)
                if self.num_outputs is None:
                    result = (result,)
            return tuple(result)

        chk = Checkpointing(function, batch)
        task = Task(self.stream, compute=chk.checkpoint, finalize=chk.recompute)
        del function, chk

    else:

        def compute(
            batch: Batch = batch,
            chunk_id: int = chunk,
            rank: int = pipeline_record.rank if pipeline_record is not None else -1,
        ) -> Batch:
            # Use the captured ``rank`` default so the closure does not depend
            # on ``pipeline_record`` at call time.
            with record_function("chunk%d-rank%d" % (chunk_id, rank)):
                result = self.module(*batch.tensors)
                if self.num_outputs is None:
                    result = (result,)
            return Batch(result, chunk_id)

        task = Task(self.stream, compute=compute, finalize=None)
        del compute

    self.in_queue.put(task)

    ok, payload = self.out_queue.get()

    # Hold the first exception.
    if exc_info is not None:
        pass
    elif not ok:
        exc_info = cast(ExcInfo, payload)
    else:
        task, batch = cast(Tuple[Task, Batch], payload)

        with use_device(self.device):
            task.finalize(batch)

        pipeline_record.batches[chunk] = batch

    if exc_info is not None:
        raise exc_info[0].with_traceback(exc_info[1], exc_info[2])
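
# Hedged sketch of the worker-side counterpart (inferred from the queue
# protocol above, not taken from the source): a worker thread pops a Task,
# runs its compute step, and reports either (True, (task, batch)) or
# (False, exc_info) so that `compute` can finalize or re-raise. The loop body
# and the None shutdown sentinel are assumptions.
import sys
from queue import Queue

def _worker_loop(in_queue: Queue, out_queue: Queue) -> None:
    while True:
        task = in_queue.get()
        if task is None:  # assumed shutdown sentinel
            break
        try:
            batch = task.compute()
        except Exception:
            out_queue.put((False, sys.exc_info()))
            continue
        out_queue.put((True, (task, batch)))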