    def test_epoch_end(self, outputs):
        if not outputs:
            return
        if parallel_state.is_pipeline_last_stage():
            # only the last pipeline parallel stages return loss
            averaged_loss = torch.stack(outputs).mean()
        else:
            averaged_loss = torch.tensor(0.0).cuda()

        # we can only log on rank zero, so we broadcast the loss from the last rank
        torch.distributed.broadcast(averaged_loss, get_last_rank())
        self.log('test_loss', averaged_loss, prog_bar=True, rank_zero_only=True)
        self.log(
            'consumed_samples',
            self.compute_consumed_samples(self.trainer.global_step - self.init_global_step),
            rank_zero_only=True,
        )

        return averaged_loss
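
The pattern above, broadcasting the loss from the last pipeline stage before logging, recurs throughout these examples: only the last stage computes a loss, but PyTorch Lightning logs from rank zero, so the value must first be made identical on all ranks. Below is a minimal sketch of the idea under a plain torch.distributed setup; the get_last_rank helper here is a hypothetical stand-in for NeMo's utility of the same name.

import torch
import torch.distributed as dist

def get_last_rank() -> int:
    # hypothetical stand-in: the globally last rank, which in a pipeline-parallel
    # layout belongs to the last pipeline stage
    return dist.get_world_size() - 1

def broadcast_loss_for_logging(local_loss: torch.Tensor) -> torch.Tensor:
    # every rank participates in the collective; ranks that did not compute a loss
    # pass a zero placeholder that the broadcast overwrites
    dist.broadcast(local_loss, src=get_last_rank())
    return local_loss  # now identical on all ranks, safe to log from rank zero
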
    def training_step(self, batch, batch_idx):
        """
            Our dataloaders produce a micro-batch and then we fetch
            a number of microbatches depending on the global batch size and model parallel size
            from the dataloader to produce a list of microbatches.
            Batch should be a list of microbatches and those microbatches should on CPU.
            Microbatches are then moved to GPU during the pipeline.
            The list of microbatches is then piped through the pipeline using Apex fwd/bwd functions.
        """
        # we zero grads here because we also call backward in the apex fwd/bwd functions
        self._optimizer.zero_grad()

        # we prepare the micro batches for the apex fwd/bwd function
        batch_for_pipeline = self.process_global_batch(batch)
        encoder_seq_length = batch_for_pipeline[0].size(1)
        decoder_seq_length = batch_for_pipeline[1].size(1)
        tensor_shape = [encoder_seq_length, get_micro_batch_size(), self.cfg.hidden_size]

        if self.cfg.get('pipeline_model_parallel_size', 1) > 1:
            losses_reduced_per_micro_batch = forward_backward_pipelining_without_interleaving(
                forward_step_func=self.get_forward_output_and_loss_func(),
                batch=batch_for_pipeline,
                model=self.enc_dec_model,
                forward_only=False,
                tensor_shape=tensor_shape,
                decoder_sequence_length=decoder_seq_length,
                dtype=self.autocast_dtype,
                grad_scaler=self.trainer.precision_plugin.scaler if self.cfg.precision == 16 else None,
            )
        else:
            losses_reduced_per_micro_batch = forward_backward_no_pipelining(
                forward_step_func=self.get_forward_output_and_loss_func(),
                batch=batch_for_pipeline,
                model=self.enc_dec_model,
                forward_only=False,
                tensor_shape=tensor_shape,
                decoder_sequence_length=decoder_seq_length,
                dtype=self.autocast_dtype,
                grad_scaler=self.trainer.precision_plugin.scaler if self.cfg.precision == 16 else None,
            )

        # only the last stages of the pipeline return losses
        if losses_reduced_per_micro_batch:
            # average loss across micro batches
            loss_tensors_list = [loss_reduced['avg'] for loss_reduced in losses_reduced_per_micro_batch]
            loss_tensor = torch.concat(loss_tensors_list)
            loss_mean = loss_tensor.mean()
        else:
            loss_mean = torch.tensor(0.0).cuda()

        # TODO: if we're not using pipeline, then we should do async allreduce (better perf)
        # in order to do this with O2, we need the async handler to be added to apex fwd/bwd function
        if self.megatron_amp_o2:
            # main grads are stored in the MainParamsOptimizer wrapper
            self._optimizer.allreduce_main_grads()  # @sangkug we think this is fine

            self.allreduce_word_and_position_embeddings()
        else:
            self.allreduce_gradients()  # @sangkug we think this is causing memory to blow up (hurts perf)

            self.allreduce_word_and_position_embeddings()

        ## logging
        # we can only log on rank zero, so we broadcast the loss from the last rank
        # we can avoid this broadcast by updating the PTL log function to accept specific ranks
        torch.distributed.broadcast(loss_mean, get_last_rank())

        if self.cfg.precision == 16:
            loss_scale = self.trainer.precision_plugin.scaler._scale
            if loss_scale is not None:
                self.log('loss_scale', loss_scale)

        self.log('reduced_train_loss', loss_mean, prog_bar=True, rank_zero_only=True)
        lr = self._optimizer.param_groups[0]['lr']
        self.log('lr', lr, rank_zero_only=True)
        self.log('global_step', self.trainer.global_step, prog_bar=True, rank_zero_only=True)
        # TODO: make sure compute_consumed_samples works for pipeline parallelism
        self.log(
            'consumed_samples',
            self.compute_consumed_samples(self.trainer.global_step - self.init_global_step),
            prog_bar=True,
            rank_zero_only=True,
        )
        return loss_mean
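
As the docstring above explains, the global batch arriving from the dataloader is split into micro-batches before being handed to the Apex fwd/bwd functions. Here is a rough sketch of that split under the usual Megatron relationship between per-rank batch size and micro batch size; the helper below is illustrative only, not NeMo's actual process_global_batch.

from typing import List
import torch

def split_into_micro_batches(per_rank_batch: torch.Tensor, micro_batch_size: int) -> List[torch.Tensor]:
    # the per-data-parallel-rank slice of the global batch is cut into equally sized
    # micro-batches; the pipeline schedule then moves each one to GPU in turn
    assert per_rank_batch.size(0) % micro_batch_size == 0
    return list(torch.split(per_rank_batch, micro_batch_size, dim=0))

# e.g. a per-rank batch of 8 sequences with micro_batch_size=2 yields 4 micro-batches
micro_batches = split_into_micro_batches(torch.zeros(8, 512, dtype=torch.long), micro_batch_size=2)
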
Example #3
    def eval_epoch_end(self, outputs, mode):
        if not outputs:
            return
        if isinstance(outputs[0], dict):
            outputs = [outputs]

        loss_list = []
        bleu_score_list = []
        for dataloader_idx, output in enumerate(outputs):
            if parallel_state.is_pipeline_last_stage():
                # only the last pipeline parallel stages return loss
                averaged_loss = torch.stack([x['loss'] for x in output]).mean()
            else:
                averaged_loss = torch.tensor(0.0).to(self.device)

            # we can only log on rank zero, so we broadcast the loss from the last rank
            torch.distributed.broadcast(averaged_loss, get_last_rank())

            # averaged_loss = average_losses_across_data_parallel_group([x['loss'] for x in output])
            inputs = list(itertools.chain(*[x['inputs'] for x in output]))
            translations = list(
                itertools.chain(*[x['translations'] for x in output]))
            ground_truths = list(
                itertools.chain(*[x['ground_truths'] for x in output]))
            assert len(translations) == len(inputs)
            assert len(translations) == len(ground_truths)

            # Gather translations and ground truths from all workers
            tr_gt_inp = [
                None
                for _ in range(parallel_state.get_data_parallel_world_size())
            ]
            # we also need to drop pairs where ground truth is an empty string
            torch.distributed.all_gather_object(
                tr_gt_inp,
                [(t, g, i)
                 for (t, g, i) in zip(translations, ground_truths, inputs)],
                group=parallel_state.get_data_parallel_group(),
            )
            if parallel_state.get_data_parallel_rank() == 0:
                _translations = []
                _ground_truths = []
                _inputs = []

                # Deduplicate sentences that may have been distributed across multiple data parallel ranks.
                gt_inp_set = set()
                for rank in range(
                        0, parallel_state.get_data_parallel_world_size()):
                    for t, g, i in tr_gt_inp[rank]:
                        if g + i not in gt_inp_set:
                            gt_inp_set.add(g + i)
                            _translations.append(t)
                            _ground_truths.append(g)
                            _inputs.append(i)

                if self.tgt_language in ['ja']:
                    sacre_bleu = corpus_bleu(_translations, [_ground_truths],
                                             tokenize="ja-mecab")
                elif self.tgt_language in ['zh']:
                    sacre_bleu = corpus_bleu(_translations, [_ground_truths],
                                             tokenize="zh")
                else:
                    sacre_bleu = corpus_bleu(_translations, [_ground_truths],
                                             tokenize="13a")

                bleu_score = sacre_bleu.score * parallel_state.get_data_parallel_world_size()

                dataset_name = "Validation" if mode == 'val' else "Test"
                logging.info(
                    f"{dataset_name}, Dataloader index: {dataloader_idx}, Set size: {len(_translations)}"
                )
                logging.info(
                    f"{dataset_name}, Dataloader index: {dataloader_idx}, SacreBLEU = {bleu_score / parallel_state.get_data_parallel_world_size()}"
                )
                logging.info(
                    f"{dataset_name}, Dataloader index: {dataloader_idx}, Translation Examples:"
                )
                logging.info(
                    '============================================================'
                )
                for example_idx in range(0, 3):
                    random_index = random.randint(0, len(_translations) - 1)
                    logging.info("    " +
                                 '\u0332'.join(f"Example {example_idx}:"))
                    logging.info(f"    Input:        {_inputs[random_index]}")
                    logging.info(
                        f"    Prediction:   {_translations[random_index]}")
                    logging.info(
                        f"    Ground Truth: {_ground_truths[random_index]}")
                    logging.info(
                        '============================================================'
                    )

            else:
                bleu_score = 0.0

            loss_list.append(averaged_loss.cpu().numpy())
            bleu_score_list.append(bleu_score)
            if dataloader_idx == 0:
                self.log(f'{mode}_sacreBLEU', bleu_score, sync_dist=True)
                self.log(f'{mode}_loss', averaged_loss, prog_bar=True)
                if self.multilingual:
                    self._log_multilingual_bleu_and_loss(
                        dataloader_idx, bleu_score, averaged_loss, mode)
            else:
                if self.multilingual:
                    self._log_multilingual_bleu_and_loss(
                        dataloader_idx, bleu_score, averaged_loss, mode)
                else:
                    self.log(f'{mode}_sacreBLEU_dl_index_{dataloader_idx}',
                             bleu_score,
                             sync_dist=True)
                    self.log(f'{mode}_loss_dl_index_{dataloader_idx}',
                             averaged_loss,
                             prog_bar=False)

        if len(loss_list) > 1:
            self.log(f"{mode}_loss_avg", np.mean(loss_list), sync_dist=True)
            self.log(f"{mode}_sacreBLEU_avg",
                     np.mean(bleu_score_list),
                     sync_dist=True)
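
The all_gather_object plus deduplication step above exists because distributed samplers may repeat sentences so that every data-parallel rank sees the same number of samples; scoring those duplicates would distort BLEU. Below is a small sketch of the same keying scheme (ground truth concatenated with input) applied outside the distributed context, assuming sacrebleu is installed.

from sacrebleu import corpus_bleu

def dedup_and_score(gathered):
    # 'gathered' mimics tr_gt_inp: one list of (translation, ground_truth, input) tuples per rank
    seen, translations, ground_truths = set(), [], []
    for rank_items in gathered:
        for t, g, i in rank_items:
            key = g + i  # same deduplication key the example uses
            if key not in seen:
                seen.add(key)
                translations.append(t)
                ground_truths.append(g)
    # sacrebleu expects the references wrapped in a list of reference corpora
    return corpus_bleu(translations, [ground_truths], tokenize="13a").score
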
Example #4
    def training_step(self, batch, batch_idx):
        """
            Our dataloaders produce a micro-batch and then we fetch
            a number of microbatches depending on the global batch size and model parallel size
            from the dataloader to produce a list of microbatches.
            Batch should be a list of microbatches and those microbatches should on CPU.
            Microbatches are then moved to GPU during the pipeline.
            The list of microbatches is then piped through the pipeline using Apex fwd/bwd functions.
        """

        # we zero grads here because we also call backward in the apex fwd/bwd functions
        self._optimizer.zero_grad()

        # we prepare the micro batches for the apex fwd/bwd function
        batch_for_pipeline = self.process_global_batch(batch)
        tensor_shape = [
            self.cfg.encoder_seq_length, self.cfg.micro_batch_size,
            self.cfg.hidden_size
        ]

        if self.cfg.get('pipeline_model_parallel_size', 1) > 1:
            losses_reduced_per_micro_batch = forward_backward_pipelining_without_interleaving(
                forward_step_func=self.get_forward_output_and_loss_func(),
                batch=batch_for_pipeline,
                model=self.model,
                forward_only=False,
                tensor_shape=tensor_shape,
                dtype=self.autocast_dtype,
                grad_scaler=self.trainer.precision_plugin.scaler
                if self.cfg.precision == 16 else None,
            )
        else:
            # no pipeline parallelism so we reduce grads asynchronously
            if self.megatron_amp_o2:
                custom_sync_context_handler = self._optimizer.no_sync
            else:
                # TODO: enable async grad all reduce for O1/autocast mixed precision training
                custom_sync_context_handler = None
            losses_reduced_per_micro_batch = forward_backward_no_pipelining(
                forward_step_func=self.get_forward_output_and_loss_func(),
                batch=batch_for_pipeline,
                model=self.model,
                forward_only=False,
                tensor_shape=tensor_shape,
                dtype=self.autocast_dtype,
                grad_scaler=self.trainer.precision_plugin.scaler
                if self.cfg.precision == 16 else None,
                custom_sync_context_handler=custom_sync_context_handler,
            )

        # only the last stages of the pipeline return losses
        if losses_reduced_per_micro_batch:
            # average loss across micro batches
            loss_tensors_list = [
                loss_reduced['avg']
                for loss_reduced in losses_reduced_per_micro_batch
            ]
            loss_tensor = torch.concat(loss_tensors_list)
            loss_mean = loss_tensor.mean()
        else:
            loss_mean = torch.tensor(0.0).cuda()

        if self.megatron_amp_o2:
            # when using pipeline parallelism grads must be reduced after the pipeline (not asynchronously)
            if self.cfg.get('pipeline_model_parallel_size', 1) > 1:
                # main grads are stored in the MainParamsOptimizer wrapper
                self._optimizer.allreduce_main_grads()
        else:
            # async grad allreduce is not currently implemented for O1/autocasting mixed precision training
            # so we allreduce gradients after the pipeline
            self.allreduce_gradients()  # @sangkug we think this is causing memory to blow up (hurts perf)

        if self.cfg.get('pipeline_model_parallel_size', 1) > 1:
            # when using pipeline parallelism the first and last stage must keep embeddings in sync
            self.allreduce_first_last_embeddings()

        # When async grad allreduce is enabled, backprop keeps moving forward without waiting
        # for the async grad allreduce operations to finish. To guarantee correct gradient
        # reduction, we cannot start the weight update until all async grad allreduces are done.
        if self.megatron_amp_o2 and self.cfg.get('pipeline_model_parallel_size', 1) == 1:
            torch.cuda.synchronize()

        ## logging
        # we can only log on rank zero, so we broadcast the loss from the last rank
        # we can avoid this broadcast by updating the PTL log function to accept specific ranks
        torch.distributed.broadcast(loss_mean, get_last_rank())

        if self.cfg.precision == 16:
            loss_scale = self.trainer.precision_plugin.scaler._scale
            if loss_scale is not None:
                self.log('loss_scale', loss_scale)

        self.log('reduced_train_loss',
                 loss_mean,
                 prog_bar=True,
                 rank_zero_only=True)
        lr = self._optimizer.param_groups[0]['lr']
        self.log('lr', lr, rank_zero_only=True)
        self.log('global_step',
                 self.trainer.global_step,
                 prog_bar=True,
                 rank_zero_only=True)
        # TODO: make sure compute_consumed_samples works for pipeline parallelism
        self.log(
            'consumed_samples',
            self.compute_consumed_samples(self.trainer.global_step -
                                          self.init_global_step),
            prog_bar=True,
            rank_zero_only=True,
        )

        return loss_mean
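
The custom_sync_context_handler passed above follows the same idea as DistributedDataParallel's no_sync(): gradient all-reduce is suppressed for every micro-batch except the last, so gradients accumulate locally and communication is paid once per global batch. A rough sketch of that pattern with vanilla DDP follows; the function and the micro-batch loop are illustrative, not NeMo's implementation.

import contextlib
import torch

def accumulate_then_sync(ddp_model, micro_batches, loss_fn):
    # suppress gradient all-reduce for all but the last micro-batch
    for step, batch in enumerate(micro_batches):
        is_last = step == len(micro_batches) - 1
        ctx = contextlib.nullcontext() if is_last else ddp_model.no_sync()
        with ctx:
            loss = loss_fn(ddp_model(batch))
            loss.backward()  # grads accumulate locally; the all-reduce fires on the last pass
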
Example #5
    def validation_epoch_end(self, outputs):
        averaged_loss = average_losses_across_data_parallel_group(outputs)

        # we can only log on rank zero, so we broadcast the loss from the last rank
        torch.distributed.broadcast(averaged_loss, get_last_rank())
        self.log('val_loss', averaged_loss, prog_bar=True, rank_zero_only=True)
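
average_losses_across_data_parallel_group reduces the per-step validation losses so that every data-parallel worker ends up holding the same mean before it is broadcast for logging. Below is a minimal sketch of what such an average amounts to under a plain torch.distributed process group; this is not NeMo's implementation.

import torch
import torch.distributed as dist

def average_losses_across_group(losses, group=None):
    # stack the per-step losses, then all-reduce so every rank holds the group mean;
    # assumes a CUDA/NCCL backend, hence the .cuda()
    averaged = torch.stack([l.detach().clone() for l in losses]).mean().cuda()
    dist.all_reduce(averaged, op=dist.ReduceOp.SUM, group=group)
    averaged /= dist.get_world_size(group=group)
    return averaged
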