    def send_skip_tensors(
            self, batch: Batch, i: int,
            skip_trackers: List[SkipTrackerThroughPotals]) -> None:
        ranks = get_pipeline_parallel_ranks()
        this_rank = torch.distributed.get_rank()

        for next_j, ns, name in self.skip_layout.copy_policy_by_src(
                self.group.rank()):
            life = skip_trackers[i].portals[(ns, name)].tensor_life
            loaded = skip_trackers[i].load(batch, ns, name)
            if loaded is not None:
                tensors = tuple([loaded])
            else:
                tensors = tuple()

            self.transport.send_message(
                PipeMessage(
                    this_rank,
                    ranks[next_j],
                    queue_name=SKIP_TENSOR_QUEUE,
                    args=(i, ns, name, life),
                    tensors=tensors,
                ),
                sync=True,
            )
Example #2
    def run_invocation(
        self,
        batch: Batch,
        partition: ModuleWrapper,
        skip_trackers: List[SkipTrackerThroughPotals],
        invocation: Invocation,
    ) -> Batch:
        """Actually run the forward pass for a given module, and send the result
        to the next stage in the pipeline if needed."""

        task = create_task(
            self.checkpoint_stop,
            batch.index,
            self.group.rank(),
            batch,
            partition.module,
            skip_trackers,
        )
        result = task.compute()
        task.finalize(result)

        if invocation.dest and invocation.dest.stage != invocation.this.stage:
            ranks = get_pipeline_parallel_ranks()
            dst_rank = ranks[invocation.dest.stage]
            result = self.send_async_message(dst_rank, result, invocation)
        return result
Example #3
    def async_send_inner(self, batch: Batch, index: int) -> Tuple[Batch, PipeMessage]:
        task = create_task_without_skip_trackers(
            self.checkpoint_stop, index, self.group.rank(), batch, self.partitions[0].module,
        )
        result = task.compute()
        task.finalize(result)

        ranks = get_pipeline_parallel_ranks()
        this_rank = torch.distributed.get_rank()

        body = AsyncMessageBody(
            AsyncMessageType.Activations,
            index,
            Location(this_rank, 0),
            Location(ranks[ranks.index(this_rank) + 1], 0),
            0,
        )
        message = PipeMessage(
            this_rank,
            ranks[ranks.index(this_rank) + 1],
            queue_name=EVENT_LOOP_ACTIVATIONS_QUEUE,
            args=body,
            tensors=tuple([*result]),
        )
        return result, message
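Across these examples the list returned by `get_pipeline_parallel_ranks()` is indexed by pipeline stage, and the neighbouring stage is found by position in that list rather than by arithmetic on the rank value itself. A minimal standalone sketch of that lookup, using a made-up rank list purely for illustration:

    # Sketch only: `pipeline_ranks` stands in for get_pipeline_parallel_ranks()
    # and `this_rank` for torch.distributed.get_rank().
    pipeline_ranks = [3, 5, 7, 9]            # global ranks, ordered by pipeline stage
    this_rank = 7

    stage = pipeline_ranks.index(this_rank)  # this process is stage 2
    next_rank = pipeline_ranks[stage + 1]    # 9: where activations are sent
    prev_rank = pipeline_ranks[stage - 1]    # 5: where gradients are sent back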
Example #4
    def backward(ctx, *grad: Tensor,) -> Tuple[Optional[Tensor], ...]:
        ranks = get_pipeline_parallel_ranks()
        this_rank = torch.distributed.get_rank()
        body = AsyncMessageBody(
            AsyncMessageType.Gradients, ctx.index, source=ctx.args.dest, dest=ctx.args.source, order=ctx.args.order - 1
        )
        ctx.transport.send_message(
            PipeMessage(
                this_rank, ranks[ctx.args.source.stage], queue_name=ctx.queue_name, args=body, tensors=tuple(grad),
            ),
            sync=True,
        )

        tail_ctx = getattr(ctx, "tail_ctx", None)
        if tail_ctx:
            expected_gradients = tail_ctx.expected_gradients
            while expected_gradients > 0:
                message = ctx.transport.recv_message_header(ctx.queue_name)

                args: AsyncMessageBody = message.args
                assert args.message_type is AsyncMessageType.Gradients

                invocation = tail_ctx.invocations[args.order]
                expected_gradients -= tail_ctx.count_per_order[invocation.order]
                AsyncEventLoop.perform_backward_for_invocation(ctx.transport, message, tail_ctx.activations, invocation)

        return (None, None, None, None, None)
Example #5
    def event_loop_tail_across_minibatches(
        self, lm_dataloader: DataLoader, criterion: nn.Module, optimizer: Optimizer, transform_logger_object: Any
    ) -> None:
        # handles one epoch

        cur_rank = self.group.rank()
        N = len(get_pipeline_parallel_ranks())
        num_batches = len(lm_dataloader)
        lm_iter = enumerate(lm_dataloader)
        # last partition -> one forward / one backward -> no warmup
        count = 0
        num_gradients = 0
        activations = dict()

        log_interval = 1
        word_counter = 0
        total_loss = 0

        while True:
            try:
                start_time = time.time()
                microbatch_index, cur_batch = next(lm_iter)
                reqd_target = transform_logger_object.transform_target(cur_batch).to(self.input_device)

                # one forward
                message = self.transport.recv_message_header(EVENT_LOOP_ACTIVATIONS_QUEUE)
                args: AsyncMessageBody = message.args
                assert args.microbatch_index == count
                batch = self.get_batch_from_message(message, EVENT_LOOP_GRADIENTS_QUEUE)

                if self.weight_prediction:
                    optimizer.update_weight_using_future_predictions(cur_rank, N, forward=True)
                task = create_task_without_skip_trackers(
                    self.checkpoint_stop, args.microbatch_index, self.group.rank(), batch, self.partitions[0].module,
                )
                output = task.compute()
                activations[args.microbatch_index] = output
                task.finalize(output)
                # one backward
                if self.weight_prediction:
                    optimizer.update_weight_using_future_predictions(cur_rank, N, forward=False)

                output_tensor = transform_logger_object.transform_output_before_loss(output.tensor)
                loss = criterion(output_tensor, reqd_target)
                loss.backward()
                count += 1
                num_gradients += 1

                if self.perform_optimizer_step(optimizer, num_gradients):
                    optimizer.step()
                    optimizer.zero_grad()
                    transform_logger_object.check_and_save_weights(num_gradients)

                transform_logger_object.log_loss(cur_batch, loss, count)
                del loss
                del activations[args.microbatch_index]
            except StopIteration:
                break
Example #6
    def event_loop_across_minibatches(self, lm_dataloader: DataLoader,
                                      criterion: nn.Module,
                                      optimizer: Optimizer,
                                      transform_logger_object: Any) -> None:
        activations: Dict[int, Batch] = dict()
        num_microbatch = len(lm_dataloader)
        num_activations = 0
        num_gradients = 0

        ranks = get_pipeline_parallel_ranks()  # for warmup phase
        N = len(ranks)
        cur_rank = torch.distributed.get_rank()

        # warmup phase (forward passes)
        # cur_rank worker will do (max_rank - cur_rank) forward passes
        n_warmup = ranks[-1] - cur_rank
        for _ in range(n_warmup):
            if self.weight_prediction:
                optimizer.update_weight_using_future_predictions(
                    cur_rank, N, forward=True)  # type: ignore
            message = self.event_loop_trunk_forward_helper(activations)
            self.transport.send_message(message, sync=True)
            num_activations += 1

        # common loop for remaining items in the warmup phase and steady phase
        while num_activations < num_microbatch:
            # 1 Forward
            if self.weight_prediction:
                optimizer.update_weight_using_future_predictions(
                    cur_rank, N, forward=True)  # type: ignore
            message = self.event_loop_trunk_forward_helper(activations)

            num_activations += 1
            # 1 Backward
            if self.weight_prediction:
                optimizer.update_weight_using_future_predictions(
                    cur_rank, N, forward=False)  # type: ignore
            self.event_loop_trunk_backward_helper(activations)
            num_gradients += 1
            if self.perform_optimizer_step(optimizer, num_gradients):
                optimizer.step()
                optimizer.zero_grad()
                transform_logger_object.check_and_save_weights(num_gradients)

            self.transport.send_message(message, sync=True)

        # remaining backwards
        remaining = len(activations)
        for _ in range(remaining):
            if self.weight_prediction:
                optimizer.update_weight_using_future_predictions(
                    cur_rank, N, forward=False)  # type: ignore
            self.event_loop_trunk_backward_helper(activations)
            num_gradients += 1
            if self.perform_optimizer_step(optimizer, num_gradients):
                optimizer.step()
                optimizer.zero_grad()
                transform_logger_object.check_and_save_weights(num_gradients)
Example #7
    def forward(ctx, transport: Transport, input: List[Tensor],
                index: int) -> Tensors:
        ranks = get_pipeline_parallel_ranks()
        src_rank = torch.distributed.get_rank()
        dst_rank = ranks[ranks.index(src_rank) + 1]

        transport.send_message(
            PipeMessage(src_rank,
                        dst_rank,
                        queue_name=ACTIVATIONS_GRADS_QUEUE,
                        args=index,
                        tensors=tuple(input)), )
        return ()
Example #8
    def backward(ctx, *grad: Tensor,) -> Tuple[Optional[Tensor], ...]:
        ranks = get_pipeline_parallel_ranks()
        this_rank = torch.distributed.get_rank()
        ctx.transport.send_message(
            PipeMessage(
                this_rank,
                ranks[ranks.index(this_rank) - 1],
                queue_name=ACTIVATIONS_GRADS_QUEUE,
                args=ctx.index,
                tensors=tuple(grad),
            ),
        )
        return (None, None, None, None, None)
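The forward/backward pairs above look like the two halves of a torch.autograd.Function, which is how a cross-stage send gets spliced into the autograd graph: forward ships activations toward the next rank and stashes on `ctx` whatever backward will need, while backward ships the incoming gradients back toward the previous rank and returns one entry per forward argument. A minimal sketch of that pattern, assuming nothing about the fairscale internals (the actual sends are elided):

    import torch

    class _PipeSendSketch(torch.autograd.Function):
        @staticmethod
        def forward(ctx, transport, index, *tensors):
            ctx.transport = transport      # stashed for use in backward
            ctx.index = index
            # ... a real implementation would send `tensors` to the next rank here ...
            return tensors                 # pass activations through unchanged

        @staticmethod
        def backward(ctx, *grad):
            # ... and send `grad` back to the previous rank here ...
            # one return value per forward argument: transport, index, *tensors
            return (None, None) + grad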
Example #9
    def send_portal_grad(self, ns_name: Tuple[Namespace, str], index: int, grad: TensorOrTensors) -> None:
        dest, src = self.skip_layout.by_ns_name.get(ns_name, (-1, -1))
        if dest == src:
            return
        ranks = get_pipeline_parallel_ranks()
        dst_rank = ranks[dest]
        if dst_rank == torch.distributed.get_rank():
            return

        if isinstance(grad, Tensor):
            grad = tuple([grad])
        self.transport.send_message(
            PipeMessage(ranks[src], dst_rank, queue_name=PORTAL_QUEUE, args=(ns_name, index), tensors=grad), sync=True,
        )
Example #10
    def execute_task(self, task: Task, i: int, skip_trackers: List[SkipTrackerThroughPotals]) -> Batch:
        batch = task.compute()

        rank = self.group.rank()

        if not self.final_stage:
            ranks = get_pipeline_parallel_ranks()
            this_rank = torch.distributed.get_rank()

            self.send_skip_tensors(this_rank, ranks, batch, i, skip_trackers)
            SendOperator.apply(this_rank, ranks[ranks.index(this_rank) + 1], self.transport, [*batch], i)

        for portal in skip_trackers[i].portals.values():
            portal.pipeline = self

        task.finalize(batch)

        return batch
Example #11
    def event_loop_head_across_minibatches(
        self, lm_dataloader: DataLoader, criterion: nn.Module, optimizer: Optimizer, transform_logger_object: Any
    ) -> None:
        # handles one epoch

        cur_rank = self.group.rank()
        N = len(get_pipeline_parallel_ranks())  # for warmup phase
        activations = dict()
        count = 0
        num_gradients = 0
        lm_iter = iter(lm_dataloader)

        # filling the pipeline: warmup  -> all N - 1 forward passes
        while True:
            try:
                cur_batch = next(lm_iter)
                reqd_input = transform_logger_object.transform_input(cur_batch).to(self.input_device)
                batch = Batch(reqd_input, count)
                if self.weight_prediction:
                    optimizer.update_weight_using_future_predictions(cur_rank, N, forward=True)
                activations[count], message = self.async_send_inner(batch, count)
                self.transport.send_message(message, sync=True)
                count += 1
                if count == N - 1:
                    break
            except StopIteration:
                break

        # steady state
        while True:
            try:
                # 1 forward pass
                cur_batch = next(lm_iter)
                reqd_input = transform_logger_object.transform_input(cur_batch).to(self.input_device)
                batch = Batch(reqd_input, count)
                if self.weight_prediction:
                    optimizer.update_weight_using_future_predictions(cur_rank, N, forward=True)
                activations[count], forward_message = self.async_send_inner(batch, count)
                count += 1

                # 1 backward pass
                message = self.transport.recv_message_header(EVENT_LOOP_GRADIENTS_QUEUE)
                args: AsyncMessageBody = message.args
                assert args.message_type is AsyncMessageType.Gradients
                if self.weight_prediction:
                    optimizer.update_weight_using_future_predictions(cur_rank, N, forward=False)
                self.async_grad_inner(message, activations)

                # Send after grad
                self.transport.send_message(forward_message, sync=True)

                num_gradients += 1
                if self.perform_optimizer_step(optimizer, num_gradients):
                    optimizer.step()
                    optimizer.zero_grad()
                    transform_logger_object.check_and_save_weights(num_gradients)

            except StopIteration:
                break

        # remaining items for backward
        remaining_items = len(activations)
        for _ in range(remaining_items):
            message = self.transport.recv_message_header(EVENT_LOOP_GRADIENTS_QUEUE)
            args = message.args
            assert args.message_type is AsyncMessageType.Gradients
            if self.weight_prediction:
                optimizer.update_weight_using_future_predictions(cur_rank, N, forward=False)
            self.async_grad_inner(message, activations)
            num_gradients += 1

            if self.perform_optimizer_step(optimizer, num_gradients):
                optimizer.step()
                optimizer.zero_grad()
                transform_logger_object.check_and_save_weights(num_gradients)
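The head, trunk, and tail event loops above all follow the same one-forward-one-backward schedule: a warmup phase of forward passes that fills the pipeline, a steady state that pairs every forward with a backward, and a final drain of the backwards still owed to the warmup forwards. A condensed, self-contained sketch of that schedule; `forward` and `backward` below are placeholder callables, not fairscale APIs:

    def one_forward_one_backward(num_microbatches, num_warmup, forward, backward):
        done_fwd = done_bwd = 0
        # Warmup: run a few forwards before the first backward arrives.
        for _ in range(min(num_warmup, num_microbatches)):
            forward(done_fwd)
            done_fwd += 1
        # Steady state: one forward, then one backward, per iteration.
        while done_fwd < num_microbatches:
            forward(done_fwd)
            done_fwd += 1
            backward(done_bwd)
            done_bwd += 1
        # Drain: finish the backwards left over from the warmup forwards.
        while done_bwd < done_fwd:
            backward(done_bwd)
            done_bwd += 1

    # e.g. one_forward_one_backward(8, 3, lambda i: print("fwd", i), lambda i: print("bwd", i))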