def test_epoch_end(self, outputs):
    if not outputs:
        return

    if parallel_state.is_pipeline_last_stage():
        # only the last pipeline parallel stages return loss
        averaged_loss = torch.stack(outputs).mean()
    else:
        averaged_loss = torch.tensor(0.0).cuda()

    # only rank zero can log, but the real loss lives on the last pipeline stage,
    # so we broadcast it from the last rank
    torch.distributed.broadcast(averaged_loss, get_last_rank())

    self.log('test_loss', averaged_loss, prog_bar=True, rank_zero_only=True)
    self.log(
        'consumed_samples',
        self.compute_consumed_samples(self.trainer.global_step - self.init_global_step),
        rank_zero_only=True,
    )

    return averaged_loss
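
# The broadcast above relies on `get_last_rank()` returning the global rank that owns the real
# loss (a rank in the last pipeline stage). A minimal sketch of such a helper is shown below;
# it is an illustrative assumption, not necessarily the helper NeMo ships.
def _get_last_rank_sketch():
    # Assumption: the last pipeline stage includes the highest global rank, so broadcasting
    # from world_size - 1 sources the tensor from a rank that actually computed the loss.
    return torch.distributed.get_world_size() - 1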
def training_step(self, batch, batch_idx):
    """
    Our dataloaders produce a micro-batch, and then we fetch
    a number of microbatches (depending on the global batch size and model parallel size)
    from the dataloader to produce a list of microbatches.
    The batch should be a list of microbatches, and those microbatches should be on CPU.
    Microbatches are then moved to the GPU during the pipeline.
    The list of microbatches is then piped through the pipeline using Apex fwd/bwd functions.
    """
    # we zero grads here because we also call backward in the apex fwd/bwd functions
    self._optimizer.zero_grad()

    # we prepare the micro batches for the apex fwd/bwd function
    batch_for_pipeline = self.process_global_batch(batch)
    encoder_seq_length = batch_for_pipeline[0].size(1)
    decoder_seq_length = batch_for_pipeline[1].size(1)
    tensor_shape = [encoder_seq_length, get_micro_batch_size(), self.cfg.hidden_size]

    if self.cfg.get('pipeline_model_parallel_size', 1) > 1:
        losses_reduced_per_micro_batch = forward_backward_pipelining_without_interleaving(
            forward_step_func=self.get_forward_output_and_loss_func(),
            batch=batch_for_pipeline,
            model=self.enc_dec_model,
            forward_only=False,
            tensor_shape=tensor_shape,
            decoder_sequence_length=decoder_seq_length,
            dtype=self.autocast_dtype,
            grad_scaler=self.trainer.precision_plugin.scaler if self.cfg.precision == 16 else None,
        )
    else:
        losses_reduced_per_micro_batch = forward_backward_no_pipelining(
            forward_step_func=self.get_forward_output_and_loss_func(),
            batch=batch_for_pipeline,
            model=self.enc_dec_model,
            forward_only=False,
            tensor_shape=tensor_shape,
            decoder_sequence_length=decoder_seq_length,
            dtype=self.autocast_dtype,
            grad_scaler=self.trainer.precision_plugin.scaler if self.cfg.precision == 16 else None,
        )

    # only the last stages of the pipeline return losses
    if losses_reduced_per_micro_batch:
        # average loss across micro batches
        loss_tensors_list = [loss_reduced['avg'] for loss_reduced in losses_reduced_per_micro_batch]
        loss_tensor = torch.concat(loss_tensors_list)
        loss_mean = loss_tensor.mean()
    else:
        loss_mean = torch.tensor(0.0).cuda()

    # TODO: if we're not using pipeline, then we should do async allreduce (better perf)
    # in order to do this with O2, we need the async handler to be added to apex fwd/bwd function
    if self.megatron_amp_o2:
        # main grads are stored in the MainParamsOptimizer wrapper
        self._optimizer.allreduce_main_grads()  # @sangkug we think this is fine
        self.allreduce_word_and_position_embeddings()
    else:
        self.allreduce_gradients()  # @sangkug we think this is causing memory to blow up (hurts perf)
        self.allreduce_word_and_position_embeddings()

    ## logging
    # only rank zero can log, but the real loss lives on the last pipeline stage,
    # so we broadcast it from the last rank
    # we can avoid this broadcast by updating the PTL log function to accept specific ranks
    torch.distributed.broadcast(loss_mean, get_last_rank())

    if self.cfg.precision == 16:
        loss_scale = self.trainer.precision_plugin.scaler._scale
        if loss_scale is not None:
            self.log('loss_scale', loss_scale)

    self.log('reduced_train_loss', loss_mean, prog_bar=True, rank_zero_only=True)
    lr = self._optimizer.param_groups[0]['lr']
    self.log('lr', lr, rank_zero_only=True)
    self.log('global_step', self.trainer.global_step, prog_bar=True, rank_zero_only=True)
    # TODO: make sure compute_consumed_samples works for pipeline parallelism
    self.log(
        'consumed_samples',
        self.compute_consumed_samples(self.trainer.global_step - self.init_global_step),
        prog_bar=True,
        rank_zero_only=True,
    )

    return loss_mean
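
# For context on `losses_reduced_per_micro_batch` above: the `forward_step_func` handed to the
# Apex fwd/bwd helpers is expected to run one microbatch and return the model output together with
# a loss-reduction callable whose dict result carries an 'avg' entry. The sketch below illustrates
# that contract with hypothetical names and a simplified batch layout; it is not the NeMo
# `get_forward_output_and_loss_func` implementation.
def example_forward_step(batch, model):
    tokens, labels, loss_mask = batch  # hypothetical unpacking; real batches carry more fields

    output_tensor = model(tokens)

    def loss_func(output_tensor):
        # token-level cross entropy, masked and averaged over valid tokens
        losses = torch.nn.functional.cross_entropy(
            output_tensor.view(-1, output_tensor.size(-1)), labels.view(-1), reduction='none'
        )
        loss = (losses * loss_mask.view(-1)).sum() / loss_mask.sum()
        reduced_loss = average_losses_across_data_parallel_group([loss])
        return loss, {'avg': reduced_loss}

    return output_tensor, loss_func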
def eval_epoch_end(self, outputs, mode):
    if not outputs:
        return

    if isinstance(outputs[0], dict):
        outputs = [outputs]

    loss_list = []
    bleu_score_list = []
    for dataloader_idx, output in enumerate(outputs):
        if parallel_state.is_pipeline_last_stage():
            # only the last pipeline parallel stages return loss
            averaged_loss = torch.stack([x['loss'] for x in output]).mean()
        else:
            averaged_loss = torch.tensor(0.0).to(self.device)

        # only rank zero can log, but the real loss lives on the last pipeline stage,
        # so we broadcast it from the last rank
        torch.distributed.broadcast(averaged_loss, get_last_rank())

        # averaged_loss = average_losses_across_data_parallel_group([x['loss'] for x in output])
        inputs = list(itertools.chain(*[x['inputs'] for x in output]))
        translations = list(itertools.chain(*[x['translations'] for x in output]))
        ground_truths = list(itertools.chain(*[x['ground_truths'] for x in output]))
        assert len(translations) == len(inputs)
        assert len(translations) == len(ground_truths)

        # Gather translations and ground truths from all workers
        tr_gt_inp = [None for _ in range(parallel_state.get_data_parallel_world_size())]
        # we also need to drop pairs where ground truth is an empty string
        torch.distributed.all_gather_object(
            tr_gt_inp,
            [(t, g, i) for (t, g, i) in zip(translations, ground_truths, inputs)],
            group=parallel_state.get_data_parallel_group(),
        )
        if parallel_state.get_data_parallel_rank() == 0:
            _translations = []
            _ground_truths = []
            _inputs = []

            # Deduplicate sentences that may have been distributed across multiple data parallel ranks.
            gt_inp_set = set()
            for rank in range(0, parallel_state.get_data_parallel_world_size()):
                for t, g, i in tr_gt_inp[rank]:
                    if g + i not in gt_inp_set:
                        gt_inp_set.add(g + i)
                        _translations.append(t)
                        _ground_truths.append(g)
                        _inputs.append(i)

            if self.tgt_language in ['ja']:
                sacre_bleu = corpus_bleu(_translations, [_ground_truths], tokenize="ja-mecab")
            elif self.tgt_language in ['zh']:
                sacre_bleu = corpus_bleu(_translations, [_ground_truths], tokenize="zh")
            else:
                sacre_bleu = corpus_bleu(_translations, [_ground_truths], tokenize="13a")

            bleu_score = sacre_bleu.score * parallel_state.get_data_parallel_world_size()

            dataset_name = "Validation" if mode == 'val' else "Test"
            logging.info(f"{dataset_name}, Dataloader index: {dataloader_idx}, Set size: {len(_translations)}")
            logging.info(
                f"{dataset_name}, Dataloader index: {dataloader_idx}, "
                f"SacreBLEU = {bleu_score / parallel_state.get_data_parallel_world_size()}"
            )
            logging.info(f"{dataset_name}, Dataloader index: {dataloader_idx}, Translation Examples:")
            logging.info('============================================================')
            for example_idx in range(0, 3):
                random_index = random.randint(0, len(_translations) - 1)
                logging.info("    " + '\u0332'.join(f"Example {example_idx}:"))
                logging.info(f"    Input:        {_inputs[random_index]}")
                logging.info(f"    Prediction:   {_translations[random_index]}")
                logging.info(f"    Ground Truth: {_ground_truths[random_index]}")
                logging.info('============================================================')
        else:
            bleu_score = 0.0

        loss_list.append(averaged_loss.cpu().numpy())
        bleu_score_list.append(bleu_score)
        if dataloader_idx == 0:
            self.log(f'{mode}_sacreBLEU', bleu_score, sync_dist=True)
            self.log(f'{mode}_loss', averaged_loss, prog_bar=True)
            if self.multilingual:
                self._log_multilingual_bleu_and_loss(dataloader_idx, bleu_score, averaged_loss, mode)
        else:
            if self.multilingual:
                self._log_multilingual_bleu_and_loss(dataloader_idx, bleu_score, averaged_loss, mode)
            else:
                self.log(f'{mode}_sacreBLEU_dl_index_{dataloader_idx}', bleu_score, sync_dist=True)
                self.log(f'{mode}_loss_dl_index_{dataloader_idx}', averaged_loss, prog_bar=False)

    if len(loss_list) > 1:
        self.log(f"{mode}_loss_avg", np.mean(loss_list), sync_dist=True)
        self.log(f"{mode}_sacreBLEU_avg", np.mean(bleu_score_list), sync_dist=True)
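
# Note on the SacreBLEU computation above: only the tokenizer changes per target language
# ("ja-mecab" for Japanese, "zh" for Chinese, "13a" otherwise). A minimal standalone sketch of the
# same call pattern, detached from any parallel state (uses the sacrebleu package imported above):
def example_corpus_bleu():
    hypotheses = ["the cat sat on the mat"]
    references = [["the cat is sitting on the mat"]]  # one inner list per reference set
    return corpus_bleu(hypotheses, references, tokenize="13a").score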
def training_step(self, batch, batch_idx):
    """
    Our dataloaders produce a micro-batch, and then we fetch
    a number of microbatches (depending on the global batch size and model parallel size)
    from the dataloader to produce a list of microbatches.
    The batch should be a list of microbatches, and those microbatches should be on CPU.
    Microbatches are then moved to the GPU during the pipeline.
    The list of microbatches is then piped through the pipeline using Apex fwd/bwd functions.
    """
    # we zero grads here because we also call backward in the apex fwd/bwd functions
    self._optimizer.zero_grad()

    # we prepare the micro batches for the apex fwd/bwd function
    batch_for_pipeline = self.process_global_batch(batch)
    tensor_shape = [self.cfg.encoder_seq_length, self.cfg.micro_batch_size, self.cfg.hidden_size]

    if self.cfg.get('pipeline_model_parallel_size', 1) > 1:
        losses_reduced_per_micro_batch = forward_backward_pipelining_without_interleaving(
            forward_step_func=self.get_forward_output_and_loss_func(),
            batch=batch_for_pipeline,
            model=self.model,
            forward_only=False,
            tensor_shape=tensor_shape,
            dtype=self.autocast_dtype,
            grad_scaler=self.trainer.precision_plugin.scaler if self.cfg.precision == 16 else None,
        )
    else:
        # no pipeline parallelism so we reduce grads asynchronously
        if self.megatron_amp_o2:
            custom_sync_context_handler = self._optimizer.no_sync
        else:
            # TODO: enable async grad all reduce for O1/autocast mixed precision training
            custom_sync_context_handler = None
        losses_reduced_per_micro_batch = forward_backward_no_pipelining(
            forward_step_func=self.get_forward_output_and_loss_func(),
            batch=batch_for_pipeline,
            model=self.model,
            forward_only=False,
            tensor_shape=tensor_shape,
            dtype=self.autocast_dtype,
            grad_scaler=self.trainer.precision_plugin.scaler if self.cfg.precision == 16 else None,
            custom_sync_context_handler=custom_sync_context_handler,
        )

    # only the last stages of the pipeline return losses
    if losses_reduced_per_micro_batch:
        # average loss across micro batches
        loss_tensors_list = [loss_reduced['avg'] for loss_reduced in losses_reduced_per_micro_batch]
        loss_tensor = torch.concat(loss_tensors_list)
        loss_mean = loss_tensor.mean()
    else:
        loss_mean = torch.tensor(0.0).cuda()

    if self.megatron_amp_o2:
        # when using pipeline parallelism grads must be reduced after the pipeline (not asynchronously)
        if self.cfg.get('pipeline_model_parallel_size', 1) > 1:
            # main grads are stored in the MainParamsOptimizer wrapper
            self._optimizer.allreduce_main_grads()
    else:
        # async grad allreduce is not currently implemented for O1/autocast mixed precision training,
        # so we allreduce gradients after the pipeline
        self.allreduce_gradients()  # @sangkug we think this is causing memory to blow up (hurts perf)

    if self.cfg.get('pipeline_model_parallel_size', 1) > 1:
        # when using pipeline parallelism the first and last stage must keep embeddings in sync
        self.allreduce_first_last_embeddings()

    # while async grad allreduce is enabled, bprop will keep moving forward without waiting for
    # the async grad allreduce work to finish. Hence, to guarantee the correctness of the grad
    # reduction, we cannot start the weight update until all async grad allreduce work is done.
    if self.megatron_amp_o2 and self.cfg.get('pipeline_model_parallel_size', 1) == 1:
        torch.cuda.synchronize()

    ## logging
    # only rank zero can log, but the real loss lives on the last pipeline stage,
    # so we broadcast it from the last rank
    # we can avoid this broadcast by updating the PTL log function to accept specific ranks
    torch.distributed.broadcast(loss_mean, get_last_rank())

    if self.cfg.precision == 16:
        loss_scale = self.trainer.precision_plugin.scaler._scale
        if loss_scale is not None:
            self.log('loss_scale', loss_scale)

    self.log('reduced_train_loss', loss_mean, prog_bar=True, rank_zero_only=True)
    lr = self._optimizer.param_groups[0]['lr']
    self.log('lr', lr, rank_zero_only=True)
    self.log('global_step', self.trainer.global_step, prog_bar=True, rank_zero_only=True)
    # TODO: make sure compute_consumed_samples works for pipeline parallelism
    self.log(
        'consumed_samples',
        self.compute_consumed_samples(self.trainer.global_step - self.init_global_step),
        prog_bar=True,
        rank_zero_only=True,
    )

    return loss_mean
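
# The `custom_sync_context_handler` used above defers gradient reduction so it can overlap with the
# backward passes of later microbatches. The O2 optimizer wrapper's `no_sync` is assumed to behave
# like the familiar DDP-style pattern sketched below (a hypothetical helper, not part of NeMo):
def run_with_deferred_grad_sync(model, microbatches, loss_fn, sync_context=None):
    # Accumulate grads locally for all but the last microbatch, then let the final backward
    # trigger (or overlap with) the gradient reduction outside the no-sync context.
    import contextlib

    ctx = sync_context() if sync_context is not None else contextlib.nullcontext()
    with ctx:
        for mb in microbatches[:-1]:
            loss_fn(model(mb)).backward()
    loss_fn(model(microbatches[-1])).backward()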
def validation_epoch_end(self, outputs):
    averaged_loss = average_losses_across_data_parallel_group(outputs)

    # only rank zero can log, but the real loss lives on the last pipeline stage,
    # so we broadcast it from the last rank
    torch.distributed.broadcast(averaged_loss, get_last_rank())

    self.log('val_loss', averaged_loss, prog_bar=True, rank_zero_only=True)
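
# For reference, `average_losses_across_data_parallel_group` is assumed to behave roughly like the
# sketch below: collect the per-step losses into one tensor, all-reduce it over the data-parallel
# group, and divide by the data-parallel world size. This is an illustrative sketch, not the exact
# Megatron/Apex helper.
def _average_losses_across_dp_group_sketch(losses):
    averaged = torch.cat([loss.clone().detach().view(1) for loss in losses])
    torch.distributed.all_reduce(averaged, group=parallel_state.get_data_parallel_group())
    averaged = averaged / torch.distributed.get_world_size(
        group=parallel_state.get_data_parallel_group()
    )
    return averaged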