def event_loop_tail_across_minibatches(
    self, lm_dataloader: DataLoader, criterion: nn.Module, optimizer: Optimizer, transform_logger_object: Any
) -> None:
    # handles one epoch
    cur_rank = self.group.rank()
    N = len(get_pipeline_parallel_ranks())
    num_batches = len(lm_dataloader)
    lm_iter = enumerate(lm_dataloader)

    # last partition -> one forward / one backward -> no warmup
    count = 0
    num_gradients = 0
    activations = dict()
    log_interval = 1
    word_counter = 0
    total_loss = 0

    while True:
        try:
            start_time = time.time()
            microbatch_index, cur_batch = next(lm_iter)
            reqd_target = transform_logger_object.transform_target(cur_batch).to(self.input_device)

            # one forward
            message = self.transport.recv_message_header(EVENT_LOOP_ACTIVATIONS_QUEUE)
            args: AsyncMessageBody = message.args
            assert args.microbatch_index == count

            batch = self.get_batch_from_message(message, EVENT_LOOP_GRADIENTS_QUEUE)

            if self.weight_prediction:
                optimizer.update_weight_using_future_predictions(cur_rank, N, forward=True)
            task = create_task_without_skip_trackers(
                self.checkpoint_stop,
                args.microbatch_index,
                self.group.rank(),
                batch,
                self.partitions[0].module,
            )
            output = task.compute()
            activations[args.microbatch_index] = output
            task.finalize(output)

            # one backward
            if self.weight_prediction:
                optimizer.update_weight_using_future_predictions(cur_rank, N, forward=False)
            output_tensor = transform_logger_object.transform_output_before_loss(output.tensor)
            loss = criterion(output_tensor, reqd_target)
            loss.backward()
            count += 1
            num_gradients += 1

            if self.perform_optimizer_step(optimizer, num_gradients):
                optimizer.step()
                optimizer.zero_grad()
                transform_logger_object.check_and_save_weights(num_gradients)

            transform_logger_object.log_loss(cur_batch, loss, count)
            del loss
            del activations[args.microbatch_index]
        except StopIteration:
            break
def event_loop_across_minibatches(
    self, lm_dataloader: DataLoader, criterion: nn.Module, optimizer: Optimizer, transform_logger_object: Any
) -> None:
    activations: Dict[int, Batch] = dict()
    num_microbatch = len(lm_dataloader)
    num_activations = 0
    num_gradients = 0

    ranks = get_pipeline_parallel_ranks()  # for warmup phase
    N = len(ranks)
    cur_rank = torch.distributed.get_rank()

    # warmup phase (forward passes)
    # cur_rank worker will do (max_rank - cur_rank) forward passes
    n_warmup = ranks[-1] - cur_rank
    for _ in range(n_warmup):
        if self.weight_prediction:
            optimizer.update_weight_using_future_predictions(cur_rank, N, forward=True)  # type: ignore
        message = self.event_loop_trunk_forward_helper(activations)
        self.transport.send_message(message, sync=True)
        num_activations += 1

    # common loop for remaining items in the warmup phase and steady phase
    while num_activations < num_microbatch:
        # 1 Forward
        if self.weight_prediction:
            optimizer.update_weight_using_future_predictions(cur_rank, N, forward=True)  # type: ignore
        message = self.event_loop_trunk_forward_helper(activations)
        num_activations += 1

        # 1 Backward
        if self.weight_prediction:
            optimizer.update_weight_using_future_predictions(cur_rank, N, forward=False)  # type: ignore
        self.event_loop_trunk_backward_helper(activations)
        num_gradients += 1

        if self.perform_optimizer_step(optimizer, num_gradients):
            optimizer.step()
            optimizer.zero_grad()
            transform_logger_object.check_and_save_weights(num_gradients)

        self.transport.send_message(message, sync=True)

    # remaining backwards
    remaining = len(activations)
    for _ in range(remaining):
        if self.weight_prediction:
            optimizer.update_weight_using_future_predictions(cur_rank, N, forward=False)  # type: ignore
        self.event_loop_trunk_backward_helper(activations)
        num_gradients += 1

        if self.perform_optimizer_step(optimizer, num_gradients):
            optimizer.step()
            optimizer.zero_grad()
            transform_logger_object.check_and_save_weights(num_gradients)
def event_loop_head_across_minibatches(
    self, lm_dataloader: DataLoader, criterion: nn.Module, optimizer: Optimizer, transform_logger_object: Any
) -> None:
    # handles one epoch
    cur_rank = self.group.rank()
    N = len(get_pipeline_parallel_ranks())  # for warmup phase
    activations = dict()
    count = 0
    num_gradients = 0
    lm_iter = iter(lm_dataloader)

    # filling the pipeline: warmup -> all N - 1 forward passes
    while True:
        try:
            cur_batch = next(lm_iter)
            reqd_input = transform_logger_object.transform_input(cur_batch).to(self.input_device)
            batch = Batch(reqd_input, count)
            if self.weight_prediction:
                optimizer.update_weight_using_future_predictions(cur_rank, N, forward=True)
            activations[count], message = self.async_send_inner(batch, count)
            self.transport.send_message(message, sync=True)
            count += 1
            if count == N - 1:
                break
        except StopIteration:
            break

    # steady state
    while True:
        try:
            # 1 forward pass
            cur_batch = next(lm_iter)
            reqd_input = transform_logger_object.transform_input(cur_batch).to(self.input_device)
            batch = Batch(reqd_input, count)
            if self.weight_prediction:
                optimizer.update_weight_using_future_predictions(cur_rank, N, forward=True)
            activations[count], forward_message = self.async_send_inner(batch, count)
            count += 1

            # 1 backward pass
            message = self.transport.recv_message_header(EVENT_LOOP_GRADIENTS_QUEUE)
            args: AsyncMessageBody = message.args
            assert args.message_type is AsyncMessageType.Gradients
            if self.weight_prediction:
                optimizer.update_weight_using_future_predictions(cur_rank, N, forward=False)
            self.async_grad_inner(message, activations)

            # Send after grad
            self.transport.send_message(forward_message, sync=True)

            num_gradients += 1
            if self.perform_optimizer_step(optimizer, num_gradients):
                optimizer.step()
                optimizer.zero_grad()
                transform_logger_object.check_and_save_weights(num_gradients)
        except StopIteration:
            break

    # remaining items for backward
    remaining_items = len(activations)
    for _ in range(remaining_items):
        message = self.transport.recv_message_header(EVENT_LOOP_GRADIENTS_QUEUE)
        args = message.args
        assert args.message_type is AsyncMessageType.Gradients
        if self.weight_prediction:
            optimizer.update_weight_using_future_predictions(cur_rank, N, forward=False)
        self.async_grad_inner(message, activations)

        num_gradients += 1
        if self.perform_optimizer_step(optimizer, num_gradients):
            optimizer.step()
            optimizer.zero_grad()
            transform_logger_object.check_and_save_weights(num_gradients)
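

# Illustrative sketch (not part of the original module): a standalone simulation of
# the 1F1B-style schedule that the head/trunk/tail event loops above implement.
# Assuming pipeline ranks 0..N-1, each rank runs (N - 1 - rank) warmup forwards to
# fill the pipeline, then alternates one forward with one backward, and finally
# drains the backwards for activations still in flight. The function name and the
# ("F"/"B", microbatch_index) step format are hypothetical and exist only to make
# the per-rank ordering easy to inspect.
def simulate_1f1b_schedule(num_ranks: int, num_microbatches: int) -> list:
    schedules = []
    for rank in range(num_ranks):
        steps = []
        fwd = bwd = 0
        # warmup phase: forwards only; deeper ranks do fewer of them
        for _ in range(min(num_ranks - 1 - rank, num_microbatches)):
            steps.append(("F", fwd))
            fwd += 1
        # steady state: one forward followed by one backward
        while fwd < num_microbatches:
            steps.append(("F", fwd))
            fwd += 1
            steps.append(("B", bwd))
            bwd += 1
        # drain: remaining backwards for outstanding activations
        while bwd < num_microbatches:
            steps.append(("B", bwd))
            bwd += 1
        schedules.append(steps)
    return schedules


# Example usage (hypothetical): simulate_1f1b_schedule(num_ranks=4, num_microbatches=8)
# shows rank 0 (the head) starting with 3 warmup forwards, while rank 3 (the tail)
# alternates one forward / one backward from the very first microbatch.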