def send_skip_tensors(
    self,
    this_rank: int,
    ranks: List[int],
    batch: Batch,
    i: int,
    skip_trackers: List[SkipTrackerThroughPotals],
) -> None:
    # Send every skip tensor that this stage is the source of, following the
    # copy policy in the skip layout. `this_rank` and `ranks` are supplied by
    # the caller (see execute_task below).
    for next_j, ns, name in self.skip_layout.copy_policy_by_src(self.group.rank()):
        life = skip_trackers[i].portals[(ns, name)].tensor_life
        loaded = skip_trackers[i].load(batch, ns, name)
        if loaded is not None:
            tensors = tuple([loaded])
        else:
            tensors = tuple()

        self.transport.send_message(
            PipeMessage(
                this_rank,
                ranks[next_j],
                queue_name=SKIP_TENSOR_QUEUE,
                args=(i, ns, name, life),
                tensors=tensors,
            ),
            sync=True,
        )

def run_invocation(
    self,
    batch: Batch,
    partition: ModuleWrapper,
    skip_trackers: List[SkipTrackerThroughPotals],
    invocation: Invocation,
) -> Batch:
    """Actually run the forward pass for a given module, and send the result
    to the next stage in the pipeline if needed."""
    task = create_task(
        self.checkpoint_stop,
        batch.index,
        self.group.rank(),
        batch,
        partition.module,
        skip_trackers,
    )
    result = task.compute()
    task.finalize(result)

    if invocation.dest and invocation.dest.stage != invocation.this.stage:
        ranks = get_pipeline_parallel_ranks()
        dst_rank = ranks[invocation.dest.stage]
        result = self.send_async_message(dst_rank, result, invocation)
    return result

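# Illustrative sketch (assumption: standalone helper, not part of the class).
# run_invocation forwards its result to another rank only when the next
# invocation lives on a different pipeline stage; same-stage invocations are
# chained locally. The predicate reduces to:
def _example_needs_cross_stage_send(this_stage: int, dest_stage: Optional[int]) -> bool:
    # e.g. (this_stage=1, dest_stage=2) -> True; (1, 1) or (1, None) -> False
    return dest_stage is not None and dest_stage != this_stage
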
def async_send_inner(self, batch: Batch, index: int) -> Tuple[Batch, PipeMessage]:
    task = create_task_without_skip_trackers(
        self.checkpoint_stop,
        index,
        self.group.rank(),
        batch,
        self.partitions[0].module,
    )
    result = task.compute()
    task.finalize(result)

    ranks = get_pipeline_parallel_ranks()
    this_rank = torch.distributed.get_rank()
    next_rank = ranks[ranks.index(this_rank) + 1]

    # Wrap the activations in a message addressed to the next stage; the
    # caller decides when to actually send it.
    body = AsyncMessageBody(
        AsyncMessageType.Activations,
        index,
        Location(this_rank, 0),
        Location(next_rank, 0),
        0,
    )
    message = PipeMessage(
        this_rank,
        next_rank,
        queue_name=EVENT_LOOP_ACTIVATIONS_QUEUE,
        args=body,
        tensors=tuple([*result]),
    )
    return result, message

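# Illustrative sketch (assumption: hypothetical rank values). The pipeline
# group lists the global ranks of the stages in stage order, so a stage finds
# its downstream neighbour by position in that list, not by rank arithmetic:
def _example_next_rank_lookup() -> None:
    ranks = [2, 5, 8, 11]  # hypothetical global ranks of four pipeline stages
    this_rank = 5  # pretend we are stage 1
    next_rank = ranks[ranks.index(this_rank) + 1]
    assert next_rank == 8  # stage 2 receives our activations
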
def backward(ctx, *grad: Tensor) -> Tuple[Optional[Tensor], ...]:
    ranks = get_pipeline_parallel_ranks()
    this_rank = torch.distributed.get_rank()

    # Ship the gradients back to the stage that produced the activations;
    # source and destination are swapped relative to the forward message.
    body = AsyncMessageBody(
        AsyncMessageType.Gradients,
        ctx.index,
        source=ctx.args.dest,
        dest=ctx.args.source,
        order=ctx.args.order - 1,
    )
    ctx.transport.send_message(
        PipeMessage(
            this_rank,
            ranks[ctx.args.source.stage],
            queue_name=ctx.queue_name,
            args=body,
            tensors=tuple(grad),
        ),
        sync=True,
    )

    # On the tail stage, keep draining gradient messages for the earlier
    # invocations of this microbatch until every expected gradient has been
    # propagated.
    tail_ctx = getattr(ctx, "tail_ctx", None)
    if tail_ctx:
        expected_gradients = tail_ctx.expected_gradients
        while expected_gradients > 0:
            message = ctx.transport.recv_message_header(ctx.queue_name)
            args: AsyncMessageBody = message.args
            assert args.message_type is AsyncMessageType.Gradients

            invocation = tail_ctx.invocations[args.order]
            expected_gradients -= tail_ctx.count_per_order[invocation.order]
            AsyncEventLoop.perform_backward_for_invocation(
                ctx.transport, message, tail_ctx.activations, invocation
            )

    # One entry per forward input, none of which receives a gradient here.
    return (None, None, None, None, None)

def event_loop_tail_across_minibatches(
    self,
    lm_dataloader: DataLoader,
    criterion: nn.Module,
    optimizer: Optimizer,
    transform_logger_object: Any,
) -> None:
    """Run the tail (last) pipeline stage for one epoch.

    The last partition performs one forward immediately followed by one
    backward per microbatch, so it needs no warmup phase."""
    cur_rank = self.group.rank()
    N = len(get_pipeline_parallel_ranks())
    lm_iter = enumerate(lm_dataloader)
    count = 0
    num_gradients = 0
    activations = dict()

    while True:
        try:
            microbatch_index, cur_batch = next(lm_iter)
            reqd_target = transform_logger_object.transform_target(cur_batch).to(self.input_device)

            # One forward: consume the activations sent by the previous stage.
            message = self.transport.recv_message_header(EVENT_LOOP_ACTIVATIONS_QUEUE)
            args: AsyncMessageBody = message.args
            assert args.microbatch_index == count

            batch = self.get_batch_from_message(message, EVENT_LOOP_GRADIENTS_QUEUE)
            if self.weight_prediction:
                optimizer.update_weight_using_future_predictions(cur_rank, N, forward=True)
            task = create_task_without_skip_trackers(
                self.checkpoint_stop,
                args.microbatch_index,
                self.group.rank(),
                batch,
                self.partitions[0].module,
            )
            output = task.compute()
            activations[args.microbatch_index] = output
            task.finalize(output)

            # One backward: compute the loss and kick off gradient propagation.
            if self.weight_prediction:
                optimizer.update_weight_using_future_predictions(cur_rank, N, forward=False)
            output_tensor = transform_logger_object.transform_output_before_loss(output.tensor)
            loss = criterion(output_tensor, reqd_target)
            loss.backward()
            count += 1
            num_gradients += 1
            if self.perform_optimizer_step(optimizer, num_gradients):
                optimizer.step()
                optimizer.zero_grad()
                transform_logger_object.check_and_save_weights(num_gradients)

            transform_logger_object.log_loss(cur_batch, loss, count)
            del loss
            del activations[args.microbatch_index]
        except StopIteration:
            break

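# Illustrative sketch (assumption: hypothetical; the real perform_optimizer_step
# is defined elsewhere in the class). The gate used above typically steps the
# optimizer once every fixed number of accumulated microbatch gradients:
def _example_perform_optimizer_step(num_gradients: int, accumulation_steps: int = 4) -> bool:
    return num_gradients > 0 and num_gradients % accumulation_steps == 0
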
def event_loop_across_minibatches(
    self,
    lm_dataloader: DataLoader,
    criterion: nn.Module,
    optimizer: Optimizer,
    transform_logger_object: Any,
) -> None:
    activations: Dict[int, Batch] = dict()
    num_microbatch = len(lm_dataloader)
    num_activations = 0
    num_gradients = 0

    ranks = get_pipeline_parallel_ranks()  # for warmup phase
    N = len(ranks)
    cur_rank = torch.distributed.get_rank()

    # Warmup phase (forward passes only): the worker at cur_rank performs
    # (max_rank - cur_rank) forward passes to fill the pipeline.
    n_warmup = ranks[-1] - cur_rank
    for _ in range(n_warmup):
        if self.weight_prediction:
            optimizer.update_weight_using_future_predictions(cur_rank, N, forward=True)  # type: ignore
        message = self.event_loop_trunk_forward_helper(activations)
        self.transport.send_message(message, sync=True)
        num_activations += 1

    # Common loop for the remaining items of the warmup phase and the steady phase.
    while num_activations < num_microbatch:
        # 1 forward
        if self.weight_prediction:
            optimizer.update_weight_using_future_predictions(cur_rank, N, forward=True)  # type: ignore
        message = self.event_loop_trunk_forward_helper(activations)
        num_activations += 1

        # 1 backward
        if self.weight_prediction:
            optimizer.update_weight_using_future_predictions(cur_rank, N, forward=False)  # type: ignore
        self.event_loop_trunk_backward_helper(activations)
        num_gradients += 1
        if self.perform_optimizer_step(optimizer, num_gradients):
            optimizer.step()
            optimizer.zero_grad()
            transform_logger_object.check_and_save_weights(num_gradients)

        # Forward this iteration's activations to the next stage only after
        # the backward has been performed.
        self.transport.send_message(message, sync=True)

    # Drain phase: backward passes for the still-outstanding activations.
    remaining = len(activations)
    for _ in range(remaining):
        if self.weight_prediction:
            optimizer.update_weight_using_future_predictions(cur_rank, N, forward=False)  # type: ignore
        self.event_loop_trunk_backward_helper(activations)
        num_gradients += 1
        if self.perform_optimizer_step(optimizer, num_gradients):
            optimizer.step()
            optimizer.zero_grad()
            transform_logger_object.check_and_save_weights(num_gradients)

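# Illustrative sketch (assumption: standalone, not library code). Simulates the
# schedule realized by event_loop_across_minibatches for one trunk stage:
# n_warmup forwards to fill the pipeline, then alternating forward/backward
# until all microbatches are seen, then a drain of the outstanding backwards.
def _example_trunk_schedule(num_microbatch: int, n_warmup: int) -> List[str]:
    schedule = []
    in_flight = 0
    num_activations = 0
    for _ in range(n_warmup):
        schedule.append("F")
        num_activations += 1
        in_flight += 1
    while num_activations < num_microbatch:
        schedule.append("F")  # 1 forward
        num_activations += 1
        in_flight += 1
        schedule.append("B")  # 1 backward
        in_flight -= 1
    schedule.extend("B" * in_flight)  # drain remaining backwards
    return schedule

# e.g. _example_trunk_schedule(6, 2) -> F F F B F B F B F B B B
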
def forward(ctx, src_rank: int, dst_rank: int, transport: Transport, input: List[Tensor], index: int) -> Tensors:
    # Push the activations for microbatch `index` to the next stage; the ranks
    # are supplied by the caller (see SendOperator.apply in execute_task below).
    transport.send_message(
        PipeMessage(
            src_rank,
            dst_rank,
            queue_name=ACTIVATIONS_GRADS_QUEUE,
            args=index,
            tensors=tuple(input),
        ),
    )
    return ()

def backward(ctx, *grad: Tensor) -> Tuple[Optional[Tensor], ...]:
    # Gradients travel in the reverse direction: back to the previous rank on
    # the same activations/gradients queue.
    ranks = get_pipeline_parallel_ranks()
    this_rank = torch.distributed.get_rank()
    ctx.transport.send_message(
        PipeMessage(
            this_rank,
            ranks[ranks.index(this_rank) - 1],
            queue_name=ACTIVATIONS_GRADS_QUEUE,
            args=ctx.index,
            tensors=tuple(grad),
        ),
    )
    return (None, None, None, None, None)

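# Illustrative sketch (assumption: standalone toy with no transport involved).
# It shows the autograd contract the send/recv pair above relies on: backward
# must return exactly one entry per forward input, with None for every input
# that takes no gradient (ranks, transports, and plain ints here).
class _ToyIdentitySend(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor: Tensor, index: int) -> Tensor:
        ctx.index = index  # stash the non-tensor argument for backward
        return tensor.clone()

    @staticmethod
    def backward(ctx, grad: Tensor):
        # One slot per forward input: a gradient for `tensor`, None for `index`.
        return grad, None

# x = torch.ones(2, requires_grad=True)
# _ToyIdentitySend.apply(x, 7).sum().backward()  # x.grad is now all ones
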
def send_portal_grad(self, ns_name: Tuple[Namespace, str], index: int, grad: TensorOrTensors) -> None:
    dest, src = self.skip_layout.by_ns_name.get(ns_name, (-1, -1))
    # If the skip never crossed a stage boundary, there is no remote gradient
    # to deliver.
    if dest == src:
        return
    ranks = get_pipeline_parallel_ranks()
    dst_rank = ranks[dest]
    if dst_rank == torch.distributed.get_rank():
        return

    if isinstance(grad, Tensor):
        grad = tuple([grad])
    self.transport.send_message(
        PipeMessage(
            ranks[src],
            dst_rank,
            queue_name=PORTAL_QUEUE,
            args=(ns_name, index),
            tensors=grad,
        ),
        sync=True,
    )

def execute_task(self, task: Task, i: int, skip_trackers: List[SkipTrackerThroughPotals]) -> Batch:
    batch = task.compute()

    if not self.final_stage:
        ranks = get_pipeline_parallel_ranks()
        this_rank = torch.distributed.get_rank()
        self.send_skip_tensors(this_rank, ranks, batch, i, skip_trackers)
        SendOperator.apply(this_rank, ranks[ranks.index(this_rank) + 1], self.transport, [*batch], i)

    for portal in skip_trackers[i].portals.values():
        portal.pipeline = self

    task.finalize(batch)
    return batch

def event_loop_head_across_minibatches(
    self,
    lm_dataloader: DataLoader,
    criterion: nn.Module,
    optimizer: Optimizer,
    transform_logger_object: Any,
) -> None:
    """Run the head (first) pipeline stage for one epoch."""
    cur_rank = self.group.rank()
    N = len(get_pipeline_parallel_ranks())  # for warmup phase
    activations = dict()
    count = 0
    num_gradients = 0
    lm_iter = iter(lm_dataloader)

    # Filling the pipeline: warmup consists of N - 1 forward passes.
    while True:
        try:
            cur_batch = next(lm_iter)
            reqd_input = transform_logger_object.transform_input(cur_batch).to(self.input_device)
            batch = Batch(reqd_input, count)
            if self.weight_prediction:
                optimizer.update_weight_using_future_predictions(cur_rank, N, forward=True)
            activations[count], message = self.async_send_inner(batch, count)
            self.transport.send_message(message, sync=True)
            count += 1
            if count == N - 1:
                break
        except StopIteration:
            break

    # Steady state: one forward and one backward pass per iteration.
    while True:
        try:
            # 1 forward pass
            cur_batch = next(lm_iter)
            reqd_input = transform_logger_object.transform_input(cur_batch).to(self.input_device)
            batch = Batch(reqd_input, count)
            if self.weight_prediction:
                optimizer.update_weight_using_future_predictions(cur_rank, N, forward=True)
            activations[count], forward_message = self.async_send_inner(batch, count)
            count += 1

            # 1 backward pass
            message = self.transport.recv_message_header(EVENT_LOOP_GRADIENTS_QUEUE)
            args: AsyncMessageBody = message.args
            assert args.message_type is AsyncMessageType.Gradients
            if self.weight_prediction:
                optimizer.update_weight_using_future_predictions(cur_rank, N, forward=False)
            self.async_grad_inner(message, activations)

            # Send the activations only after the gradients have been received.
            self.transport.send_message(forward_message, sync=True)

            num_gradients += 1
            if self.perform_optimizer_step(optimizer, num_gradients):
                optimizer.step()
                optimizer.zero_grad()
                transform_logger_object.check_and_save_weights(num_gradients)
        except StopIteration:
            break

    # Drain phase: backward passes for the still-outstanding activations.
    remaining_items = len(activations)
    for _ in range(remaining_items):
        message = self.transport.recv_message_header(EVENT_LOOP_GRADIENTS_QUEUE)
        args = message.args
        assert args.message_type is AsyncMessageType.Gradients
        if self.weight_prediction:
            optimizer.update_weight_using_future_predictions(cur_rank, N, forward=False)
        self.async_grad_inner(message, activations)
        num_gradients += 1
        if self.perform_optimizer_step(optimizer, num_gradients):
            optimizer.step()
            optimizer.zero_grad()
            transform_logger_object.check_and_save_weights(num_gradients)

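# Illustrative sketch (assumption: standalone summary of the three event loops
# above). Head, trunk, and tail together realize a 1F1B schedule: the warmup
# depth shrinks by one stage as you move down the pipeline, and the tail needs
# none at all.
def _example_warmup_depth(ranks: List[int], cur_rank: int) -> int:
    # With ranks [0, 1, 2, 3]: head -> 3 warmup forwards, trunk ranks -> 2 and
    # 1, tail -> 0 (it runs one forward and one backward per microbatch).
    return ranks[-1] - cur_rank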