def test_gather_tensors():
    a = torch.zeros(1, 1)
    b = torch.zeros(1, 1)

    ab = gather([Batch(a), Batch(b)])

    assert ab.size() == (2, 1)
def test_gather_tuples():
    a = (torch.zeros(1, 1), torch.zeros(2, 2))
    b = (torch.zeros(1, 1), torch.zeros(2, 2))

    ab = gather([Batch(a), Batch(b)])

    assert isinstance(ab, tuple)
    assert ab[0].size() == (2, 1)
    assert ab[1].size() == (4, 2)
def get_batch_from_message(self, message: PipeMessage, queue_name: int) -> Batch:
    """Get the tensor(s) wrapped in a `Batch` from a `PipeMessage`, applying
    AsyncRecvOperator so we can intercept the backward pass."""
    microbatch_index = message.args.microbatch_index
    phony = torch.empty(0, device=self.transport.input_device, requires_grad=True)
    result = AsyncRecvOperator.apply(phony, self.transport, message, queue_name)
    if len(result) == 1:
        batch = Batch(result[0], microbatch_index)
    else:
        batch = Batch(result, microbatch_index)
    return batch
def test_batch_setitem_by_slice():
    a = Batch(torch.tensor(42))
    b = Batch((torch.tensor(42), torch.tensor(21)))

    a[:] = (torch.tensor(0),)
    b[:] = (torch.tensor(0),)

    assert a.atomic
    assert a[0].item() == 0

    assert not b.atomic
    assert len(b) == 1
    assert b[0].item() == 0
def test_batch_setitem_by_index():
    a = Batch(torch.tensor(42))
    b = Batch((torch.tensor(42), torch.tensor(21)))

    a[0] = torch.tensor(0)
    b[0] = torch.tensor(0)

    assert a.atomic
    assert a[0].item() == 0

    assert not b.atomic
    assert len(b) == 2
    assert b[0].item() == 0
    assert b[1].item() == 21
def compute(
    batch: Batch = batch,
    partition: nn.Sequential = partition,
    chunk_id: int = i,
    part_id: int = j,
) -> Batch:
    with record_function("chunk%d-part%d" % (chunk_id, part_id)):
        return batch.call(partition)
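# Note: the default arguments above pin the loop variables (batch, partition,
# i, j) at definition time. A minimal illustration (not from the source) of
# the late-binding pitfall this idiom avoids:
def _late_binding_sketch():
    fns = [lambda: i for i in range(3)]
    assert [f() for f in fns] == [2, 2, 2]  # late binding: all closures see the final i

    fns = [lambda i=i: i for i in range(3)]
    assert [f() for f in fns] == [0, 1, 2]  # default argument pins each value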
def wait_for(self, chunk: int) -> None:
    """Waits until all elements of the given chunk are populated in
    ``self.tensors``, then constructs ``self.batches[chunk]`` if it has not
    been constructed yet.
    """
    with self.ready_cv:
        while self.batches[chunk] is None and any(b is None for b in self.tensors[chunk]):
            self.ready_cv.wait()
        if self.batches[chunk] is None:
            tensors = cast(List[Tensor], self.tensors[chunk])
            self.batches[chunk] = Batch(tuple(tensors), chunk)
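# Hedged sketch of the producer side that wait_for() blocks on. The method
# name `feed` and its signature are assumptions, not taken from this section;
# the point is that the writer must update the slot and notify ready_cv while
# holding the same lock.
def feed_sketch(self, chunk: int, index: int, tensor: Tensor) -> None:
    with self.ready_cv:
        self.tensors[chunk][index] = tensor  # populate one slot of the chunk
        self.ready_cv.notify_all()           # wake threads blocked in wait_for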
def make_checkpoint(function: Function, input: TensorOrTensors, index: int) -> TensorOrTensors:
    """Makes a checkpoint with a simple interface like
    :func:`torch.utils.checkpoint.checkpoint`. It's only used to test or debug
    :class:`Checkpoint` and :class:`Recompute` without boilerplate.
    """
    batch = Batch(input, index)

    chk = Checkpointing(function, batch)
    batch = chk.checkpoint()
    chk.recompute(batch)

    return batch.tensor_or_tensors
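# Minimal usage sketch for make_checkpoint (the module and input below are
# ours, not from the source). The checkpointed forward discards activations;
# chk.recompute() hooks recomputation into the autograd graph so it runs
# during backward.
def _make_checkpoint_usage_sketch():
    f = nn.Linear(4, 4)
    x = torch.rand(2, 4, requires_grad=True)
    y = make_checkpoint(f, x, 0)  # forward under checkpointing
    y.sum().backward()            # recomputes f(x), then backpropagates
    assert x.grad is not None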
def compute(
    batch: Batch = batch,
    chunk_id: int = chunk,
    rank: int = pipeline_record.rank if pipeline_record is not None else -1,
) -> Batch:
    # Use the captured `rank` default rather than pipeline_record.rank, which
    # may be None here.
    with record_function("chunk%d-rank%d" % (chunk_id, rank)):
        result = self.module(*batch.tensors)
        if self.num_outputs is None:
            result = (result,)
    return Batch(result, chunk_id)
def test_batch_atomic():
    x = torch.tensor(42)
    b = Batch(x)

    assert b.atomic

    assert b.tensor is x
    with pytest.raises(AttributeError):
        b.tensors

    assert list(b) == [x]
    assert len(b) == 1
    assert b[0] is x
def test_tensor_life_without_checkpointing():
    skip_layout = SkipLayout(num_partitions=2, skip_routes={(None, "test"): (0, 1)})
    skip_tracker = SkipTrackerThroughPotals(skip_layout)

    batch = Batch(torch.tensor([1.0]))
    tensor = torch.tensor([2.0])

    skip_tracker.save(batch, None, "test", tensor)
    assert skip_tracker.portals[(None, "test")].tensor_life == 1

    skip_tracker.load(batch, None, "test")
    assert skip_tracker.portals[(None, "test")].tensor_life == 0
def test_reuse_portal():
    skip_layout = SkipLayout(num_partitions=2, skip_routes={(None, "test"): (0, 1)})
    skip_tracker = SkipTrackerThroughPotals(skip_layout)

    batch = Batch(torch.tensor([1.0]))
    a = torch.tensor([2.0])
    b = torch.tensor([2.0])

    skip_tracker.save(batch, None, "test", a)
    portal = skip_tracker.portals[(None, "test")]

    skip_tracker.save(batch, None, "test", b)
    assert portal is skip_tracker.portals[(None, "test")]
def test_batch_non_atomic():
    x, y = torch.tensor(42), torch.tensor(21)
    b = Batch((x, y))

    assert not b.atomic

    with pytest.raises(AttributeError):
        b.tensor
    assert b.tensors == (x, y)

    assert list(b) == [x, y]
    assert len(b) == 2
    assert b[0] is x
    assert b[1] is y
def test_batch_call():
    a = Batch(torch.tensor(42))
    b = Batch((torch.tensor(42), torch.tensor(21)))

    def f(x):
        return x

    assert a.call(f).atomic
    assert not b.call(f).atomic
def test_not_requires_grad():
    x = Batch(torch.rand(1, requires_grad=False))
    assert not x[0].requires_grad

    def f(x):
        return x * 2

    chk = Checkpointing(f, x)
    x = chk.checkpoint()
    assert x[0].requires_grad

    chk.recompute(x)
    assert x[0].requires_grad

    x.tensor.backward()
def fence(self, chunk: int) -> None:
    """Prepares micro-batches for computation."""
    # Ensure that batches[chunk-1] is executed after batches[chunk] in
    # backpropagation by an explicit dependency.
    # TODO: This dependency injection causes deadlock if this partition
    # gets its input from model input. 1) Figure out why. 2) If we need to
    # live with this constraint, replace the condition 'self.rank > 0'
    # below with a more accurate one.
    if chunk != 0 and self.consumers and self.rank > 0:
        batch = self.batches[chunk]
        assert batch is not None
        dependant_tensors = list(batch.tensors)
        for remote_ph_list in self.forwarded_phony[chunk - 1]:
            for remote_ph in remote_ph_list:
                phony = remote_ph.to_here()
                dependant_tensors[0] = join(dependant_tensors[0], phony)
        self.batches[chunk] = Batch(tuple(dependant_tensors), chunk)
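# Hedged sketch of the dependency trick used above, assuming the fork()/join()
# pair from the pipe dependency module: join() threads an autograd-only
# "phony" tensor into a real tensor, so backward ordering is constrained
# without changing any values.
def _fork_join_sketch():
    x = torch.rand(1, requires_grad=True)
    y = torch.rand(1, requires_grad=True)
    x, phony = fork(x)  # phony: empty tensor tied to x in the autograd graph
    y = join(y, phony)  # y's backward now depends on x's branch
    return x, y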
def test_no_copy_no_portal():
    skip_layout = SkipLayout(
        num_partitions=2,
        skip_routes={
            (None, "copy"): (0, 1),
            (None, "not_copy"): (0, 0),
        },
    )
    skip_tracker = SkipTrackerThroughPotals(skip_layout)

    batch = Batch(torch.tensor([1.0]))
    a = torch.tensor([2.0])
    b = torch.tensor([2.0])

    skip_tracker.save(batch, None, "copy", a)
    skip_tracker.save(batch, None, "not_copy", b)

    assert (None, "copy") in skip_tracker.portals
    assert (None, "copy") not in skip_tracker.tensors
    assert (None, "not_copy") in skip_tracker.tensors
    assert (None, "not_copy") not in skip_tracker.portals
def event_loop_head_across_minibatches(
    self, lm_dataloader: DataLoader, criterion: nn.Module, optimizer: Optimizer, transform_logger_object: Any
) -> None:
    # handles one epoch
    cur_rank = self.group.rank()
    N = len(get_pipeline_parallel_ranks())  # for warmup phase
    activations = dict()
    count = 0
    num_gradients = 0
    lm_iter = iter(lm_dataloader)

    # filling the pipeline: warmup -> all N - 1 forward passes
    while True:
        try:
            cur_batch = next(lm_iter)
            reqd_input = transform_logger_object.transform_input(cur_batch).to(self.input_device)
            batch = Batch(reqd_input, count)
            if self.weight_prediction:
                optimizer.update_weight_using_future_predictions(cur_rank, N, forward=True)
            activations[count], message = self.async_send_inner(batch, count)
            self.transport.send_message(message, sync=True)
            count += 1
            if count == N - 1:
                break
        except StopIteration:
            break

    # steady state
    while True:
        try:
            # 1 forward pass
            cur_batch = next(lm_iter)
            reqd_input = transform_logger_object.transform_input(cur_batch).to(self.input_device)
            batch = Batch(reqd_input, count)
            if self.weight_prediction:
                optimizer.update_weight_using_future_predictions(cur_rank, N, forward=True)
            activations[count], forward_message = self.async_send_inner(batch, count)
            count += 1

            # 1 backward pass
            message = self.transport.recv_message_header(EVENT_LOOP_GRADIENTS_QUEUE)
            args: AsyncMessageBody = message.args
            assert args.message_type is AsyncMessageType.Gradients
            if self.weight_prediction:
                optimizer.update_weight_using_future_predictions(cur_rank, N, forward=False)
            self.async_grad_inner(message, activations)

            # Send after grad
            self.transport.send_message(forward_message, sync=True)

            num_gradients += 1
            if self.perform_optimizer_step(optimizer, num_gradients):
                optimizer.step()
                optimizer.zero_grad()
                transform_logger_object.check_and_save_weights(num_gradients)
        except StopIteration:
            break

    # remaining items for backward
    remaining_items = len(activations)
    for _ in range(remaining_items):
        message = self.transport.recv_message_header(EVENT_LOOP_GRADIENTS_QUEUE)
        args = message.args
        assert args.message_type is AsyncMessageType.Gradients
        if self.weight_prediction:
            optimizer.update_weight_using_future_predictions(cur_rank, N, forward=False)
        self.async_grad_inner(message, activations)
        num_gradients += 1
        if self.perform_optimizer_step(optimizer, num_gradients):
            optimizer.step()
            optimizer.zero_grad()
            transform_logger_object.check_and_save_weights(num_gradients)
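# Illustrative only (not from the source): the loop above follows the classic
# warmup / steady-state / drain pipeline schedule. For a pipeline of N stages
# the head performs N - 1 warmup forwards, then alternates one forward with
# one backward, then drains the outstanding backwards.
def _schedule_sketch(num_stages: int, num_microbatches: int) -> list:
    warmup = min(num_stages - 1, num_microbatches)
    steady = max(num_microbatches - (num_stages - 1), 0)
    ops = ["F"] * warmup          # fill the pipeline
    ops += ["F", "B"] * steady    # 1F1B steady state
    ops += ["B"] * warmup         # drain remaining backwards
    return ops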
def _42():
    return Batch(torch.tensor(42), 0)
def detect_grad_enabled():
    x = torch.rand(1, requires_grad=torch.is_grad_enabled())
    return Batch(x, 0)
def counter():
    nonlocal count
    time.sleep(0.1)
    count += 1
    return Batch((), 0)
def log_thread_id():
    thread_id = threading.current_thread().ident
    thread_ids.add(thread_id)
    return Batch((), 0)