def __init__(
    self,
    module_rref: rpc.RRef,
    device: str,
    num_inputs: int,
    num_outputs: Optional[int],
    rank: int,
    chunks: int,
    checkpoint_stop: int,
) -> None:
    """Initialize this pipeline partition.

    Resolves the partition's module from its RRef, records the pipeline
    configuration, and starts the worker in/out queues for the device.
    """
    self.rank = rank
    self.chunks = chunks
    self.checkpoint_stop = checkpoint_stop
    self.num_inputs = num_inputs
    self.num_outputs = num_outputs
    self.device = torch.device(device)
    self.module = module_rref.local_value()
    # create_workers returns per-device (in_queues, out_queues); this
    # partition owns exactly one device, hence one queue of each.
    (self.in_queue,), (self.out_queue,) = create_workers([self.device])
def run_pipeline(self, pipeline_record_rref: rpc.RRef) -> Optional[Tensor]:
    """Process a mini-batch on this partition.

    If this is the last partition (the pipeline record has no consumer),
    gathers the results of all processed chunks and returns them as the
    output of the model for the whole mini-batch; otherwise returns None.
    """
    record = pipeline_record_rref.local_value()
    self.run(record)

    # Intermediate partitions forward results downstream; nothing to return.
    if record.consumers:
        return None

    gathered = microbatch.gather(record.batches)
    assert len(gathered) == 1
    output = gathered[0]
    stream = current_stream(output.device)
    if is_cuda(stream):
        # TODO. Investigate why this is needed and remove it if possible.
        as_cuda(stream).synchronize()
    return output
def test_local_rref_no_fork(self):
    # An RRef created and resolved in the same process should round-trip
    # its value without any fork to a remote worker.
    self.assertEqual(RRef(35).local_value(), 35)
def _parameter_rrefs(module: rpc.RRef) -> List[rpc.RRef]:
    """Return one RRef per parameter of the locally-held module."""
    local_module = module.local_value()
    return [rpc.RRef(parameter) for parameter in local_module.parameters()]
def _rcheckpoint(rmodule: rpc.RRef, input_rref: rpc.RRef) -> TensorOrTensors:
    """Run the locally-held sequential module under activation checkpointing.

    The module's first layer (_ToHere) materializes the remote input on this
    worker; the remaining layers are then executed via checkpoint_sequential
    with a single chunk.
    """
    sequential = rmodule.local_value()
    local_input = sequential[0](input_rref)  # calls _ToHere.forward
    return checkpoint_sequential(sequential[1:], 1, local_input)