Example #1
    def event_loop_tail_across_minibatches(
        self, lm_dataloader: DataLoader, criterion: nn.Module, optimizer: Optimizer, transform_logger_object: Any
    ) -> None:
        # handles one epoch

        cur_rank = self.group.rank()
        N = len(get_pipeline_parallel_ranks())
        num_batches = len(lm_dataloader)
        lm_iter = enumerate(lm_dataloader)
        # last partition -> one forward / one backward -> no warmup
        count = 0
        num_gradients = 0
        activations = dict()

        log_interval = 1
        word_counter = 0
        total_loss = 0

        while True:
            try:
                start_time = time.time()
                microbatch_index, cur_batch = next(lm_iter)
                reqd_target = transform_logger_object.transform_target(cur_batch).to(self.input_device)

                # one forward
                message = self.transport.recv_message_header(EVENT_LOOP_ACTIVATIONS_QUEUE)
                args: AsyncMessageBody = message.args
                assert args.microbatch_index == count
                batch = self.get_batch_from_message(message, EVENT_LOOP_GRADIENTS_QUEUE)

                if self.weight_prediction:
                    optimizer.update_weight_using_future_predictions(cur_rank, N, forward=True)
                task = create_task_without_skip_trackers(
                    self.checkpoint_stop, args.microbatch_index, self.group.rank(), batch, self.partitions[0].module,
                )
                output = task.compute()
                activations[args.microbatch_index] = output
                task.finalize(output)
                # one backward
                if self.weight_prediction:
                    optimizer.update_weight_using_future_predictions(cur_rank, N, forward=False)

                output_tensor = transform_logger_object.transform_output_before_loss(output.tensor)
                loss = criterion(output_tensor, reqd_target)
                loss.backward()
                count += 1
                num_gradients += 1

                if self.perform_optimizer_step(optimizer, num_gradients):
                    optimizer.step()
                    optimizer.zero_grad()
                    transform_logger_object.check_and_save_weights(num_gradients)

                transform_logger_object.log_loss(cur_batch, loss, count)
                del loss
                del activations[args.microbatch_index]
            except StopIteration:
                break
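All three event loops gate the optimizer update on self.perform_optimizer_step(optimizer, num_gradients), whose body is not shown in these examples. The sketch below is a hypothetical illustration of the pattern it enables, assuming it boils down to an accumulation-count check (the library's actual condition may differ): step once every accumulation_steps backward passes.

    import torch
    from torch import nn, optim

    def should_step(num_gradients: int, accumulation_steps: int) -> bool:
        # Fire an optimizer step once every `accumulation_steps` backward passes.
        return num_gradients % accumulation_steps == 0

    model = nn.Linear(4, 2)
    optimizer = optim.SGD(model.parameters(), lr=0.1)
    num_gradients = 0
    for _ in range(8):
        loss = model(torch.randn(3, 4)).sum()
        loss.backward()  # gradients accumulate in .grad between steps
        num_gradients += 1
        if should_step(num_gradients, accumulation_steps=4):
            optimizer.step()
            optimizer.zero_grad()

With accumulation_steps=4 this toy loop takes two optimizer steps over eight backward passes; in the event loops above the same gate also decides when check_and_save_weights is invoked.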
Example #2
    def event_loop_across_minibatches(self, lm_dataloader: DataLoader,
                                      criterion: nn.Module,
                                      optimizer: Optimizer,
                                      transform_logger_object: Any) -> None:
        activations: Dict[int, Batch] = dict()
        num_microbatch = len(lm_dataloader)
        num_activations = 0
        num_gradients = 0

        ranks = get_pipeline_parallel_ranks()  # for warmup phase
        N = len(ranks)
        cur_rank = torch.distributed.get_rank()

        # warmup phase (forward passes)
        # cur_rank worker will do (max_rank - cur_rank) forward passes
        n_warmup = ranks[-1] - cur_rank
        for _ in range(n_warmup):
            if self.weight_prediction:
                optimizer.update_weight_using_future_predictions(
                    cur_rank, N, forward=True)  # type: ignore
            message = self.event_loop_trunk_forward_helper(activations)
            self.transport.send_message(message, sync=True)
            num_activations += 1

        # common loop for remaining items in the warmup phase and the steady phase
        while num_activations < num_microbatch:
            # 1 Forward
            if self.weight_prediction:
                optimizer.update_weight_using_future_predictions(
                    cur_rank, N, forward=True)  # type: ignore
            message = self.event_loop_trunk_forward_helper(activations)

            num_activations += 1
            # 1 Backward
            if self.weight_prediction:
                optimizer.update_weight_using_future_predictions(
                    cur_rank, N, forward=False)  # type: ignore
            self.event_loop_trunk_backward_helper(activations)
            num_gradients += 1
            if self.perform_optimizer_step(optimizer, num_gradients):
                optimizer.step()
                optimizer.zero_grad()
                transform_logger_object.check_and_save_weights(num_gradients)

            self.transport.send_message(message, sync=True)

        # remaining backwards
        remaining = len(activations)
        for _ in range(remaining):
            if self.weight_prediction:
                optimizer.update_weight_using_future_predictions(
                    cur_rank, N, forward=False)  # type: ignore
            self.event_loop_trunk_backward_helper(activations)
            num_gradients += 1
            if self.perform_optimizer_step(optimizer, num_gradients):
                optimizer.step()
                optimizer.zero_grad()
                transform_logger_object.check_and_save_weights(num_gradients)
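The trunk loop above is a one-forward-one-backward (1F1B) schedule: ranks[-1] - cur_rank warmup forwards, a steady state that alternates one forward with one backward, and finally a drain of the backwards still pending in activations. The self-contained sketch below reproduces that ordering per worker, assuming pipeline ranks numbered 0..N-1; it is illustrative only and performs no communication or autograd work.

    def one_f_one_b_schedule(rank: int, num_ranks: int, num_microbatch: int) -> str:
        schedule = []
        pending = 0      # forwards whose backward has not run yet
        done_fwd = 0
        # warmup: (last rank - rank) forward passes
        for _ in range(min(num_ranks - 1 - rank, num_microbatch)):
            schedule.append("F")
            done_fwd += 1
            pending += 1
        # steady state: alternate one forward with one backward
        while done_fwd < num_microbatch:
            schedule.append("F")
            done_fwd += 1
            pending += 1
            schedule.append("B")
            pending -= 1
        # drain: remaining backwards
        schedule.extend("B" * pending)
        return "".join(schedule)

    for r in range(4):
        print(r, one_f_one_b_schedule(r, num_ranks=4, num_microbatch=6))

For 4 ranks and 6 microbatches this prints FFFFBFBFBBBB for rank 0 and FBFBFBFBFBFB for rank 3: the deeper the stage, the sooner it interleaves backwards, which is why the head and tail loops in the other examples need different amounts of warmup.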
Example #3
    def event_loop_head_across_minibatches(
        self, lm_dataloader: DataLoader, criterion: nn.Module, optimizer: Optimizer, transform_logger_object: Any
    ) -> None:
        # handles one epoch

        cur_rank = self.group.rank()
        N = len(get_pipeline_parallel_ranks())  # for warmup phase
        activations = dict()
        count = 0
        num_gradients = 0
        lm_iter = iter(lm_dataloader)

        # fill the pipeline: warmup phase -> N - 1 forward passes
        while True:
            try:
                cur_batch = next(lm_iter)
                reqd_input = transform_logger_object.transform_input(cur_batch).to(self.input_device)
                batch = Batch(reqd_input, count)
                if self.weight_prediction:
                    optimizer.update_weight_using_future_predictions(cur_rank, N, forward=True)
                activations[count], message = self.async_send_inner(batch, count)
                self.transport.send_message(message, sync=True)
                count += 1
                if count == N - 1:
                    break
            except StopIteration:
                break

        # steady state
        while True:
            try:
                # 1 forward pass
                cur_batch = next(lm_iter)
                reqd_input = transform_logger_object.transform_input(cur_batch).to(self.input_device)
                batch = Batch(reqd_input, count)
                if self.weight_prediction:
                    optimizer.update_weight_using_future_predictions(cur_rank, N, forward=True)
                activations[count], forward_message = self.async_send_inner(batch, count)
                count += 1

                # 1 backward pass
                message = self.transport.recv_message_header(EVENT_LOOP_GRADIENTS_QUEUE)
                args: AsyncMessageBody = message.args
                assert args.message_type is AsyncMessageType.Gradients
                if self.weight_prediction:
                    optimizer.update_weight_using_future_predictions(cur_rank, N, forward=False)
                self.async_grad_inner(message, activations)

                # Send after grad
                self.transport.send_message(forward_message, sync=True)

                num_gradients += 1
                if self.perform_optimizer_step(optimizer, num_gradients):
                    optimizer.step()
                    optimizer.zero_grad()
                    transform_logger_object.check_and_save_weights(num_gradients)

            except StopIteration:
                break

        # remaining items for backward
        remaining_items = len(activations)
        for _ in range(remaining_items):
            message = self.transport.recv_message_header(EVENT_LOOP_GRADIENTS_QUEUE)
            args = message.args
            assert args.message_type is AsyncMessageType.Gradients
            if self.weight_prediction:
                optimizer.update_weight_using_future_predictions(cur_rank, N, forward=False)
            self.async_grad_inner(message, activations)
            num_gradients += 1

            if self.perform_optimizer_step(optimizer, num_gradients):
                optimizer.step()
                optimizer.zero_grad()
                transform_logger_object.check_and_save_weights(num_gradients)
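The head never computes a loss: async_send_inner stores each microbatch's output in activations and ships it to the next stage, and when a gradient message arrives on EVENT_LOOP_GRADIENTS_QUEUE, async_grad_inner presumably runs the corresponding backward using that received gradient. The toy single-process snippet below illustrates that detached-send / received-gradient-backward pattern; it takes the helper names above as given and replaces all real message passing with local tensors.

    import torch
    from torch import nn

    stage0 = nn.Linear(4, 8)   # plays the head partition
    stage1 = nn.Linear(8, 2)   # plays the next stage (normally another rank)

    x = torch.randn(3, 4)

    # head: run its partition, keep the activation, "send" a detached copy
    act = stage0(x)
    sent = act.detach().requires_grad_()

    # next stage: forward, loss, backward; grad of its input is "sent back"
    loss = stage1(sent).sum()
    loss.backward()
    grad_from_next_stage = sent.grad

    # head, later: finish its own backward with the received gradient
    act.backward(grad_from_next_stage)
    assert stage0.weight.grad is not None

Passing the received gradient to act.backward continues backpropagation through the head's partition, which is the single-process analogue of what the gradient messages achieve across ranks.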