def on_validation_epoch_end(self):
    app_state = AppState()
    if hasattr(self, "_train_ds"):
        _reconfigure_microbatch_calculator(
            rank=app_state.global_rank,
            rampup_batch_size=None,
            global_batch_size=self.cfg.data.train_ds.global_batch_size,
            micro_batch_size=self.cfg.data.train_ds.micro_batch_size,
            data_parallel_size=parallel_state.get_data_parallel_world_size(),
        )
    # When running `trainer.validate()`, the training dataset is not available.
    else:
        logging.warning(
            'No training data found, reconfiguring microbatches based on validation batch sizes.'
        )
        _reconfigure_microbatch_calculator(
            rank=app_state.global_rank,
            rampup_batch_size=None,
            global_batch_size=self.cfg.data.validation_ds.global_batch_size,
            micro_batch_size=self.cfg.data.validation_ds.micro_batch_size,
            data_parallel_size=parallel_state.get_data_parallel_world_size(),
        )
    return super().on_validation_epoch_end()
def inference_step(self, batch, batch_idx, mode, dataloader_idx=0):
    batch_has_lang_information = len(batch[0]) == 7
    micro_batch_size = batch[0]['text_enc'].size(0)
    # This should happen only on the last batch of the dataset.
    if micro_batch_size != self.cfg.data.validation_ds.micro_batch_size:
        app_state = AppState()
        _reconfigure_microbatch_calculator(
            rank=app_state.global_rank,
            rampup_batch_size=None,
            global_batch_size=micro_batch_size
            * parallel_state.get_data_parallel_world_size()
            * get_num_microbatches(),
            micro_batch_size=micro_batch_size,
            data_parallel_size=parallel_state.get_data_parallel_world_size(),
        )

    # At this point batch is a list of dictionaries where each dict is a microbatch.
    # After the _process_global_batch call, processed_batch will be a single dictionary containing the global batch.
    # This is required since the parent class expects a single global batch dictionary.
    processed_batch = self._process_global_batch(batch)

    # Call parent validation step to get the loss.
    # NOTE: There could be extra keys in processed_batch such as "langs" for XNLI; these are ignored in the parent class.
    loss = super().validation_step(processed_batch, batch_idx)

    predicted_token_ids, _ = self.decode(
        tokens_enc=processed_batch['text_enc'],
        enc_mask=processed_batch['enc_mask'],
        num_tokens_to_generate=30,
    )

    # Special ids-to-text function to handle stripping <eos> and special tokens with sentencepiece tokenizers.
    preds_text = self.ids_to_text(predicted_token_ids)
    labels_text = self.ids_to_text(processed_batch['labels'])
    input_text = self.ids_to_text(processed_batch['text_enc'])

    if not batch_has_lang_information:
        categories = [None] * len(preds_text)
    else:
        categories = processed_batch['lang']

    metric = self.val_metric[dataloader_idx] if mode == 'validation' else self.test_metric[dataloader_idx]
    assert len(categories) == len(preds_text) == len(labels_text)
    for _, (pred, label, category) in enumerate(zip(preds_text, labels_text, categories)):
        # To compute metrics like pearson or spearman correlation, we need to cast the predicted string and labels to floats.
        pred, label = self.cast_for_metric(
            pred, label, self.val_metric_name if mode == 'validation' else self.test_metric_name
        )
        if batch_has_lang_information:
            _ = metric(pred, label, category)
        else:
            _ = metric(pred, label)

    return {
        'loss': loss,
        'preds': preds_text,
        'labels': labels_text,
        'categories': categories,
        'inputs': input_text,
    }
def eval_step(self, batch, batch_idx, dataloader_idx, data_cfg):
    # Need to squeeze dim 0 for tarred datasets since things are pre-batched and we ask the dataloader for batch size 1.
    batch = [[x.squeeze(dim=0) if x.ndim == 3 else x for x in microbatch] for microbatch in batch]
    batch = self.process_global_batch_for_tarred_datasets(batch)
    if data_cfg.dataset_type in ['tarred', 'text']:
        app_state = AppState()
        _reconfigure_microbatch_calculator(
            rank=app_state.global_rank,
            rampup_batch_size=None,
            global_batch_size=batch['text_enc'].size(0) * parallel_state.get_data_parallel_world_size(),
            micro_batch_size=batch['text_enc'].size(0),
            data_parallel_size=parallel_state.get_data_parallel_world_size(),
        )
    # This returns the averaged loss across data-parallel groups.
    reduced_loss = super().validation_step(batch, batch_idx)

    tokens_enc, labels, enc_mask = batch['text_enc'], batch['labels'], batch['enc_mask']

    predicted_tokens_ids, _ = self.decode(
        tokens_enc,
        enc_mask,
        tokens_enc.size(1)
        + self._cfg.max_generation_delta,  # Generate up to src-length + max generation delta. TODO: Implement better stopping when everything hits <EOS>.
        tokenizer=self.decoder_tokenizer,
    )

    if self.multilingual:
        source_processor = self.source_processor_list[dataloader_idx]
        target_processor = self.target_processor_list[dataloader_idx]
    else:
        source_processor = self.source_processor
        target_processor = self.target_processor

    # Post-process the translations and inputs to log.
    preds = self.postprocess_outputs(
        outputs=predicted_tokens_ids, tokenizer=self.decoder_tokenizer, processor=target_processor,
    )
    labels = self.postprocess_outputs(
        outputs=labels, tokenizer=self.decoder_tokenizer, processor=target_processor,
    )
    encoder_inputs = self.postprocess_outputs(
        outputs=tokens_enc, tokenizer=self.encoder_tokenizer, processor=source_processor,
    )

    return {
        'inputs': encoder_inputs,
        'translations': preds,
        'ground_truths': labels,
        'loss': reduced_loss,
    }
def on_validation_epoch_start(self):
    app_state = AppState()
    _reconfigure_microbatch_calculator(
        rank=app_state.global_rank,
        rampup_batch_size=None,
        global_batch_size=self.cfg.data.validation_ds.global_batch_size,
        micro_batch_size=self.cfg.data.validation_ds.micro_batch_size,
        data_parallel_size=parallel_state.get_data_parallel_world_size(),
    )
    return super().on_validation_epoch_start()
def on_validation_epoch_end(self):
    app_state = AppState()
    if hasattr(self, "_train_ds"):
        _reconfigure_microbatch_calculator(
            rank=app_state.global_rank,
            rampup_batch_size=None,
            global_batch_size=self._cfg.train_ds.global_batch_size,
            micro_batch_size=self._cfg.train_ds.micro_batch_size,
            data_parallel_size=parallel_state.get_data_parallel_world_size(),
        )
def _test(self, rampup_batch_size: Optional[List[int]]) -> None:
    for data_parallel_size in range(1, self.world_size + 1):

        expected_global_batch_size = self.GLOBAL_BATCH_SIZE
        expected_micro_batch_size = self.MICRO_BATCH_SIZE
        if rampup_batch_size:
            expected_global_batch_size = rampup_batch_size[0]
            num_consumed_samples = 0
            step_of_global_batch_size = rampup_batch_size[1]
            threshold = rampup_batch_size[2]

        if data_parallel_size > 1 and data_parallel_size % 2 != 0:
            continue
        if self.world_size % data_parallel_size != 0:
            continue
        with self.subTest(data_parallel_size=data_parallel_size):
            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size_=self.world_size // data_parallel_size,
                pipeline_model_parallel_size_=1,
            )
            self.assertEqual(data_parallel_size, parallel_state.get_data_parallel_world_size())

            _reconfigure_microbatch_calculator(
                self.rank,
                rampup_batch_size,
                self.GLOBAL_BATCH_SIZE,
                self.MICRO_BATCH_SIZE,
                data_parallel_size,
            )

            self.assertEqual(get_micro_batch_size(), expected_micro_batch_size)
            self.assertEqual(
                get_num_microbatches(),
                expected_global_batch_size / expected_micro_batch_size / data_parallel_size,
            )
            current_global_batch_size = get_current_global_batch_size()
            self.assertEqual(current_global_batch_size, expected_global_batch_size)

            # Make sure `global_batch_size` equals the final global batch size after
            # a certain number of updates.
            if rampup_batch_size:
                update_num_microbatches(current_global_batch_size)
                for i in range(100):
                    current_global_batch_size = get_current_global_batch_size()
                    update_num_microbatches(current_global_batch_size)
                current_global_batch_size = get_current_global_batch_size()
                self.assertEqual(get_current_global_batch_size(), self.GLOBAL_BATCH_SIZE)
            parallel_state.destroy_model_parallel()
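# A minimal usage sketch (not part of the test above): the three-element rampup_batch_size list is
# indexed here as [starting global batch size, global batch size increment, sample-count threshold
# controlling how fast the ramp-up progresses], following the local variable names in the test.
# The concrete numbers below are illustrative assumptions, not values taken from the test class.
#
#   _reconfigure_microbatch_calculator(
#       rank=0,
#       rampup_batch_size=[8, 8, 64],   # assumed: start at 8, step by 8, ramp over 64 samples
#       global_batch_size=32,           # final global batch size after ramp-up
#       micro_batch_size=4,
#       data_parallel_size=1,
#   )
#   assert get_current_global_batch_size() == 8  # before any samples are consumed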
def training_step(self, batch, batch_idx):
    # Need to squeeze dim 0 for tarred datasets since things are pre-batched and we ask the dataloader for batch size 1.
    batch = [[x.squeeze(dim=0) if x.ndim == 3 else x for x in microbatch] for microbatch in batch]
    batch = self.process_global_batch_for_tarred_datasets(batch)
    app_state = AppState()
    _reconfigure_microbatch_calculator(
        rank=app_state.global_rank,
        rampup_batch_size=None,
        global_batch_size=batch['text_enc'].size(0) * parallel_state.get_data_parallel_world_size(),
        micro_batch_size=batch['text_enc'].size(0),
        data_parallel_size=parallel_state.get_data_parallel_world_size(),
    )
    return super().training_step(batch, batch_idx)
def training_step(self, batch, batch_idx):
    micro_batch_size = batch[0]['text_enc'].size(0)
    # This should happen only on the last batch of the dataset.
    if micro_batch_size != self.cfg.data.train_ds.micro_batch_size:
        app_state = AppState()
        _reconfigure_microbatch_calculator(
            rank=app_state.global_rank,
            rampup_batch_size=None,
            global_batch_size=micro_batch_size
            * parallel_state.get_data_parallel_world_size()
            * get_num_microbatches(),
            micro_batch_size=micro_batch_size,
            data_parallel_size=parallel_state.get_data_parallel_world_size(),
        )
    return super().training_step(batch, batch_idx)
def training_step(self, batch, batch_idx):
    micro_batch_size = batch[0]['text_enc'].size(0)
    # This should happen only on the last batch of the dataset.
    if micro_batch_size != self.cfg.data.train_ds.micro_batch_size:
        app_state = AppState()
        _reconfigure_microbatch_calculator(
            rank=app_state.global_rank,
            rampup_batch_size=None,
            global_batch_size=micro_batch_size
            * parallel_state.get_data_parallel_world_size()
            * get_num_microbatches(),
            micro_batch_size=micro_batch_size,
            data_parallel_size=parallel_state.get_data_parallel_world_size(),
        )

    # At this point batch is a list of dictionaries where each dict is a microbatch.
    # After the _process_global_batch call, batch will be a single dictionary containing the global batch.
    # This is required since the parent class expects a single global batch dictionary.
    batch = self._process_global_batch(batch)
    return super().training_step(batch, batch_idx)
def training_step(self, batch, batch_idx):
    # Need to squeeze dim 0 for tarred datasets since things are pre-batched and we ask the dataloader for batch size 1.
    if self._cfg.train_ds.dataset_type in ['tarred', 'text']:
        batch = [[x.squeeze(dim=0) if x.ndim == 3 else x for x in microbatch] for microbatch in batch]
        batch = self.process_global_batch_for_tarred_datasets(batch)
    elif (
        self._cfg.train_ds.dataset_type in ['bin_memmap', 'text_memmap']
        and self._cfg.train_ds.get("sampler", "distributed") == 'distributed'
    ):
        batch = self._process_global_batch_without_megatron_batch_sampler(batch, tokenizer=self.encoder_tokenizer)

    if self._cfg.train_ds.dataset_type in ['tarred', 'text']:
        app_state = AppState()
        _reconfigure_microbatch_calculator(
            rank=app_state.global_rank,
            rampup_batch_size=None,
            global_batch_size=batch['text_enc'].size(0) * parallel_state.get_data_parallel_world_size(),
            micro_batch_size=batch['text_enc'].size(0),
            data_parallel_size=parallel_state.get_data_parallel_world_size(),
        )
    return super().training_step(batch, batch_idx)
def tab_sample_sequence_batch(
    model,
    context_tokens,
    context_lengths,
    attention_mask,
    position_ids,
    tokens_to_generate,
    all_probs=True,
    type_ids=None,
    temperature=None,
):
    app_state = AppState()
    micro_batch_size = context_tokens.shape[0]
    _reconfigure_microbatch_calculator(
        rank=app_state.global_rank,
        rampup_batch_size=None,
        global_batch_size=micro_batch_size,
        micro_batch_size=micro_batch_size,
        data_parallel_size=1,
    )
    tokenizer = model.tokenizer
    sizes = tokenizer.code_column.sizes
    tokens_per_row = sum(sizes) + 1
    columns = tokenizer.code_column.columns
    num_columns = len(columns)
    tokenid_range = []
    for i in range(num_columns):
        tokenid_range.extend(tokenizer.code_column.get_range(i))

    model.eval()
    with torch.no_grad():
        context_length = context_lengths.min().item()
        context = context_tokens[:, :context_length]
        # The context may start in the middle of the row;
        # calculate the offset according to the position of '\n' or '<|endoftext|>'.
        positions = torch.where(context == tokenizer.eor)[1]
        if len(positions) == 0:
            positions = torch.where(context == tokenizer.eod)[1]
        if len(positions) != 0:
            max_position = positions.max().item()
            # TODO: need to make sure the contexts of different batches have the same offset lengths;
            # otherwise, the offset needs to be calculated per batch_id.
            offset = (context_length - max_position - 1) % tokens_per_row
        else:
            offset = 0

        eod_id = tokenizer.eos_id

        counter = 0

        batch_size = context_tokens.size(0)
        is_done = torch.zeros([batch_size]).byte().cuda()
        tokens = context_tokens
        output_logits = None

        # Generate enough tokens for the longest sequence
        maxlen = tokens_to_generate + context_lengths.max().item()

        if maxlen > model.cfg.encoder_seq_length:
            maxlen = model.cfg.encoder_seq_length

        lengths = torch.ones([batch_size]).long().cuda() * maxlen

        while context_length < maxlen:
            # types2use = None
            if counter == 0:
                # Allocate memory for the entire context.
                set_inference_key_value_memory = True
                tokens2use = tokens[:, :context_length]
                positions2use = position_ids[:, :context_length]
                # not using type2use. uncomment it if it is used
                # if type_ids is not None:
                #     types2use = type_ids[:, :context_length]
            else:
                # Set this to false so the memory is not reallocated.
                set_inference_key_value_memory = False
                tokens2use = tokens[:, context_length - 1].view(batch_size, -1)
                positions2use = position_ids[:, context_length - 1].view(batch_size, -1)
                # not using type2use. uncomment it if it is used
                # if type_ids is not None:
                #     types2use = type_ids[:, context_length - 1].view(batch_size, -1)

            # micro_batch_size = 2
            attention_mask_repeat = torch.concat([attention_mask for _ in range(micro_batch_size)])
            setkey_value_array = torch.tensor(
                [set_inference_key_value_memory] * micro_batch_size, device=torch.cuda.current_device()
            )
            len_array = torch.tensor([maxlen] * micro_batch_size, device=torch.cuda.current_device())

            batch = [tokens2use, attention_mask_repeat, positions2use, setkey_value_array, len_array]
            tensor_shape = [tokens2use.shape[1], micro_batch_size, model.cfg.hidden_size]

            output = forward_step(model, batch, tensor_shape)

            if parallel_state.is_pipeline_last_stage():
                output = output[0]['logits'].float()
                output = tensor_parallel.gather_from_tensor_model_parallel_region(output)
                assert output is not None
                output = output.float()
                logits = output[:, -1].view(batch_size, -1).contiguous()
                token_in_row = (counter + offset) % tokens_per_row
                logits = logits.float()
                logits /= temperature
                if token_in_row == tokens_per_row - 1:
                    # line break
                    eor_id = tokenizer.eor
                    eod_id = tokenizer.eos_id
                    min_id = min(eor_id, eod_id)
                    max_id = max(eor_id, eod_id) + 1
                    logits = tab_logits(logits, min_id, max_id)
                else:
                    # limit the range
                    min_id, max_id = tokenid_range[token_in_row]
                    logits = tab_logits(logits, min_id, max_id)
                log_probs = F.softmax(logits, dim=-1)
                prev = torch.multinomial(log_probs, num_samples=1).view(-1)
                started = context_lengths <= context_length

                # Clamp the out of vocabulary tokens.
                prev = torch.clamp(prev, max=tokenizer.vocab_size - 1)
                new_tokens = switch(tokens[:, context_length].view(-1), prev, started)
                tokens[:, context_length] = new_tokens

                if output_logits is None:
                    output_context = F.log_softmax(output[:, :context_length, :], 2)
                    indices = torch.unsqueeze(tokens[:, 1:context_length + 1], 2)
                    output_logits = torch.gather(output_context, 2, indices).squeeze(2)
                    if all_probs:
                        full_logits = output_context
                else:
                    output_context = F.log_softmax(output, 2)
                    indices = torch.unsqueeze(new_tokens, 1).unsqueeze(2)
                    new_output_logits = torch.gather(output_context, 2, indices).squeeze(2)

                    # TODO(rprenger): we're copying output_logits every time. Should pre-allocate.
                    output_logits = torch.cat([output_logits, new_output_logits], 1)
                    if all_probs:
                        full_logits = torch.cat([full_logits, output_context], 1)

                src = parallel_state.get_pipeline_model_parallel_last_rank()
                group = parallel_state.get_embedding_group()
                torch.distributed.broadcast(new_tokens, src, group)

                done_token = (prev == eod_id).byte() & started.byte()
                just_finished = (done_token & ~is_done).bool()
                lengths[just_finished.view(-1)] = context_length
                is_done = is_done | done_token
                done = torch.all(is_done)

                src = parallel_state.get_pipeline_model_parallel_last_rank()
                group = parallel_state.get_pipeline_model_parallel_group()
                torch.distributed.broadcast(done, src, group)
                if all_probs:
                    yield tokens, lengths, output_logits, full_logits
                else:
                    yield tokens, lengths, output_logits, None

            else:
                if parallel_state.is_pipeline_first_stage():
                    src = parallel_state.get_pipeline_model_parallel_last_rank()
                    group = parallel_state.get_embedding_group()
                    new_tokens = torch.empty_like(tokens[:, context_length])
                    torch.distributed.broadcast(new_tokens, src, group)
                    tokens[:, context_length] = new_tokens
                    yield tokens, None, None, None
                else:
                    yield None, None, None, None

                done = torch.cuda.ByteTensor([0])
                src = parallel_state.get_pipeline_model_parallel_last_rank()
                group = parallel_state.get_pipeline_model_parallel_group()
                torch.distributed.broadcast(done, src, group)

            context_length += 1
            counter += 1
            if done:
                break
def _forward_backward_test_impl(
    self,
    forward_only: bool,
    fwd_bwd_func: FwdStepFunc,
    pipeline_model_parallel_world_size: Optional[int],
    virtual_pipeline_model_parallel_size: Optional[int],
    async_comm: bool = False,
    *,
    default_backend: Optional[str] = None,
    p2p_backend: Optional[str] = None,
) -> None:
    if fwd_bwd_func == _forward_backward_pipelining_with_interleaving:
        self.assertIsNotNone(virtual_pipeline_model_parallel_size)
        self.assertGreater(virtual_pipeline_model_parallel_size, 1)
    dtype_options = self.dtypes or [torch.float32, torch.double] + _get_autocast_dtypes()

    for dtype, deallocate_pipeline_outputs in itertools.product(
        dtype_options, self.deallocate_options,
    ):
        grad_scaler = torch.cuda.amp.GradScaler(init_scale=4.0) if dtype == torch.half else None

        (
            tensor_model_parallel_world_size,
            data_parallel_size,
            pipeline_model_parallel_world_size,
        ) = _get_default_world_sizes_model_parallel_world_size(pipeline_model_parallel_world_size)

        parallel_state.initialize_model_parallel(
            tensor_model_parallel_size_=tensor_model_parallel_world_size,
            pipeline_model_parallel_size_=pipeline_model_parallel_world_size,
            virtual_pipeline_model_parallel_size_=virtual_pipeline_model_parallel_size,
            default_backend=default_backend,
            p2p_backend=p2p_backend,
        )
        pp_utils._reconfigure_microbatch_calculator(
            rank=parallel_state.get_tensor_model_parallel_rank(),
            rampup_batch_size=None,
            global_batch_size=self.GLOBAL_BATCH_SIZE,
            micro_batch_size=self.MICRO_BATCH_SIZE,
            data_parallel_size=parallel_state.get_data_parallel_world_size(),
        )

        global_batch_shape = (
            self.GLOBAL_BATCH_SIZE // parallel_state.get_data_parallel_world_size(),
            self.HIDDEN_SIZE,
            self.HIDDEN_SIZE,
        )

        batch = None
        if parallel_state.is_pipeline_first_stage():
            batch = (torch.ones(global_batch_shape, dtype=dtype).cuda(),)

        model = build_model(
            testing_utils.model_provider_func,
            # Use DDP only when it's better to have
            wrap_with_ddp=data_parallel_size > 1,
            virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size,
            hidden_size=self.HIDDEN_SIZE,
        )

        offset = pipeline_model_parallel_world_size if virtual_pipeline_model_parallel_size is not None else 0
        for idx, model_module in enumerate(model):
            model_module = model_module.to(dtype)
            model_module.apply(get_init_weights_func(idx * offset))

        _param_groups = _get_params_for_weight_decay_optimization(model)
        optimizer = torch.optim.Adam(_param_groups, lr=1e-3)

        pp_utils.update_num_microbatches(0)

        loss = fwd_bwd_func(
            testing_utils.fwd_step_func,
            batch,
            model,
            forward_only=forward_only,
            # `tensor_shape` is the shape of micro batch.
            tensor_shape=(
                self.MICRO_BATCH_SIZE,
                self.HIDDEN_SIZE,
                self.HIDDEN_SIZE,
            ),
            dtype=dtype,
            async_comm=async_comm,
            grad_scaler=grad_scaler,
            deallocate_pipeline_output=deallocate_pipeline_outputs,
        )

        if dtype == torch.double:
            hidden_size = self.HIDDEN_SIZE
            microbatch_size = self.MICRO_BATCH_SIZE
            total_layers = pipeline_model_parallel_world_size
            if virtual_pipeline_model_parallel_size is not None:
                total_layers *= virtual_pipeline_model_parallel_size
            target_loss, target_model = get_target_loss_and_model(global_batch_shape, hidden_size, total_layers)

            for loss_item in loss:
                x = loss_item['avg']
                torch.testing.assert_close(x.item() / microbatch_size, target_loss.item())

            if not forward_only:
                for vm_id, model_module in enumerate(model):
                    params = list(model_module.parameters())
                    rank = params[0].get_device()
                    offset = pipeline_model_parallel_world_size
                    param_id = rank // data_parallel_size + vm_id * offset
                    target_params = target_model[param_id]

                    torch.testing.assert_close(params[0].cpu(), target_params[0])
                    torch.testing.assert_close(params[1].cpu(), target_params[1])
                    torch.testing.assert_close(params[0].grad.cpu() / microbatch_size, target_params[0].grad)
                    torch.testing.assert_close(params[1].grad.cpu() / microbatch_size, target_params[1].grad)

        if not forward_only:
            for m in model:
                for p in m.parameters():
                    self.assertIsNotNone(p.grad)
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)

        parallel_state.destroy_model_parallel()
def complete(self, request: Dict):
    """
    Autoregressively invokes language model in the inference mode

    Args:
        request: Dictionary with the following fields
            * prompt: a string which text the model should complete.
            * tokens_to_generate: how many tokens to generate while doing prompt completion.

    Returns:
        response: A python dictionary with the following fields
            * prompt: original text of the prompt
            * tokenized_prompt: list of (str) tokens from prompt
            * completion: a python dictionary with the following subfields:
                * tokens: a list of triples (token, token_id, log_prob) comprising completion
                * text: completion text (as a single string)
    """
    app_state = AppState()

    # The complete method only works with global batch size = micro batch size = data parallel size = 1.
    _reconfigure_microbatch_calculator(
        rank=app_state.global_rank,
        rampup_batch_size=None,
        global_batch_size=1,
        micro_batch_size=1,
        data_parallel_size=1,
    )
    response = {}
    self.freeze()
    # naive greedy slow loop
    # TODO: add option for BeamSearchDecoder

    response['prompt'] = request['prompt'][0]
    response['completion'] = {}
    tokens_enc = request['masked_sample']

    response['masked_input'] = ' '.join(self.tokenizer.ids_to_tokens(tokens_enc[0].cpu().numpy().tolist()))
    enc_mask = tokens_enc != self.tokenizer.pad_id

    predicted_tokens_ids, log_probs = self.decode(tokens_enc, enc_mask, int(request['tokens_to_generate']))
    predicted_tokens_ids = predicted_tokens_ids.cpu().numpy()[0].tolist()
    log_probs = log_probs.cpu().numpy()[0].tolist()
    if self.tokenizer.eos_id in predicted_tokens_ids:
        idx = predicted_tokens_ids.index(self.tokenizer.eos_id)
        predicted_tokens_ids = predicted_tokens_ids[:idx]
    else:
        predicted_tokens_ids = [id for id in predicted_tokens_ids if id != self.tokenizer.pad_id]
    # Legacy sentencepiece detokenization still preserves special tokens which messes up exact string match.
    if hasattr(self.tokenizer, 'special_token_to_id'):
        predicted_tokens_ids = [
            id for id in predicted_tokens_ids if id not in self.tokenizer.special_token_to_id.values()
        ]

    predicted_tokens_dec = self.tokenizer.ids_to_tokens(predicted_tokens_ids)
    response['completion']['text'] = self.tokenizer.tokens_to_text(predicted_tokens_dec)
    response['completion']['tokens'] = list(zip(predicted_tokens_ids, predicted_tokens_dec, log_probs))
    self.unfreeze()
    return response
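# A hedged usage sketch for complete(): the request keys follow the docstring above plus the
# 'masked_sample' tensor the method actually reads; the prompt text and token count below are
# illustrative assumptions, not values taken from this file.
#
#   ids = model.tokenizer.text_to_ids("Paris is the capital of <mask>")
#   request = {
#       'prompt': ["Paris is the capital of <mask>"],
#       'masked_sample': torch.LongTensor([ids]).cuda(),
#       'tokens_to_generate': 8,
#   }
#   response = model.complete(request)
#   print(response['masked_input'])
#   print(response['completion']['text'])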
def decode(self, tokens_enc, enc_mask, num_tokens_to_generate, encoder_input=None):
    app_state = AppState()
    global_batch_per_gpu = tokens_enc.size(0)

    num_micro_batches_before_decode = get_num_microbatches()
    # Reconfigure microbatch calculator here to set the number of microbatches to 1 while decoding,
    # since it's not clear how to decode with "grad acc".
    # TODO: reconfigure back to how things were before decode?
    _reconfigure_microbatch_calculator(
        rank=app_state.global_rank,
        rampup_batch_size=None,
        global_batch_size=global_batch_per_gpu * parallel_state.get_data_parallel_world_size(),
        micro_batch_size=global_batch_per_gpu,  # Make sure that there is no "grad acc" while decoding.
        data_parallel_size=parallel_state.get_data_parallel_world_size(),
    )
    predicted_tokens_dec = (
        torch.LongTensor([self.tokenizer.bos_id] * global_batch_per_gpu).unsqueeze(1).to(tokens_enc.device)
    )
    encoder_seq_length = tokens_enc.size(1)
    tensor_shape = [encoder_seq_length, global_batch_per_gpu, self.cfg.hidden_size]
    assert predicted_tokens_dec.size(0) == global_batch_per_gpu

    for i in range(num_tokens_to_generate):
        # No microbatches in decoding. Just the global batch.
        decoder_seq_length = predicted_tokens_dec.size(1)
        dec_mask = predicted_tokens_dec != self.tokenizer.pad_id

        if encoder_input is not None:
            batch_for_pipeline = [tokens_enc, predicted_tokens_dec, enc_mask, dec_mask, encoder_input]
        else:
            batch_for_pipeline = [tokens_enc, predicted_tokens_dec, enc_mask, dec_mask]
        if self.cfg.get('pipeline_model_parallel_size', 1) > 1:
            output_tensor = forward_backward_pipelining_without_interleaving(
                forward_step_func=self.get_forward_output_only_func(),
                batch=batch_for_pipeline,
                model=self.enc_dec_model,
                forward_only=True,
                tensor_shape=tensor_shape,
                decoder_sequence_length=decoder_seq_length,
                dtype=self.autocast_dtype,
            )
        else:
            output_tensor = forward_backward_no_pipelining(
                forward_step_func=self.get_forward_output_only_func(),
                batch=batch_for_pipeline,
                model=self.enc_dec_model,
                forward_only=True,
                tensor_shape=tensor_shape,
                decoder_sequence_length=decoder_seq_length,
                dtype=self.autocast_dtype,
            )

        # get output tensor
        if parallel_state.is_pipeline_last_stage():
            output_tensor = output_tensor[0]['logits']
            output_tensor = tensor_parallel.gather_from_tensor_model_parallel_region(output_tensor)
            log_probs, token_ids = torch.max(torch.nn.functional.log_softmax(output_tensor, dim=-1), dim=-1)
            predicted_tokens_dec = torch.cat(
                [predicted_tokens_dec.to(token_ids.device), token_ids[:, -1].unsqueeze(1)], dim=1
            )
        else:
            log_probs = torch.zeros(
                (predicted_tokens_dec.shape[0], predicted_tokens_dec.shape[1]), dtype=self.autocast_dtype
            ).cuda()
            predicted_tokens_dec = torch.zeros(
                (predicted_tokens_dec.shape[0], predicted_tokens_dec.shape[1] + 1),
                dtype=predicted_tokens_dec.dtype,
            ).cuda()

        if self.cfg.get('pipeline_model_parallel_size', 1) > 1:
            # Broadcast from the last pipeline stage to all other model-parallel ranks.
            torch.distributed.broadcast(
                predicted_tokens_dec,
                parallel_state.get_pipeline_model_parallel_last_rank(),
                group=parallel_state.get_model_parallel_group(),
            )
            torch.distributed.broadcast(
                log_probs,
                parallel_state.get_pipeline_model_parallel_last_rank(),
                group=parallel_state.get_model_parallel_group(),
            )

    # Reset microbatch calculator to what it was before decoding.
    _reconfigure_microbatch_calculator(
        rank=app_state.global_rank,
        rampup_batch_size=None,
        global_batch_size=global_batch_per_gpu * parallel_state.get_data_parallel_world_size(),
        micro_batch_size=global_batch_per_gpu // num_micro_batches_before_decode,
        data_parallel_size=parallel_state.get_data_parallel_world_size(),
    )

    return predicted_tokens_dec, log_probs
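# Sizing note for the reconfiguration above (an illustrative worked example, not code from this model):
# the microbatch calculator maintains num_microbatches = global_batch_size / (micro_batch_size * data_parallel_size).
# Setting micro_batch_size = global_batch_per_gpu and global_batch_size = global_batch_per_gpu * data_parallel_size
# therefore forces num_microbatches to 1 for the duration of decoding, e.g. with global_batch_per_gpu = 4 and
# data_parallel_size = 2: 8 / (4 * 2) == 1, so there is no gradient accumulation while generating tokens.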
def inference_step(self, batch, batch_idx, mode, dataloader_idx=0):
    batch_has_lang_information = len(batch[0]) == 7

    # XNLI batches have language information that needs to be removed before calling the parent validation step.
    if batch_has_lang_information:
        processed_batch = []
        for micro_batch in batch:
            micro_batch = {k: v for k, v in micro_batch.items() if k != 'lang'}
            processed_batch.append(micro_batch)
    else:
        processed_batch = batch

    micro_batch_size = processed_batch[0]['text_enc'].size(0)
    # This should happen only on the last batch of the dataset.
    if micro_batch_size != self.cfg.data.validation_ds.micro_batch_size:
        app_state = AppState()
        _reconfigure_microbatch_calculator(
            rank=app_state.global_rank,
            rampup_batch_size=None,
            global_batch_size=micro_batch_size
            * parallel_state.get_data_parallel_world_size()
            * get_num_microbatches(),
            micro_batch_size=micro_batch_size,
            data_parallel_size=parallel_state.get_data_parallel_world_size(),
        )

    # Call parent validation step to get the loss.
    loss = super().validation_step(processed_batch, batch_idx)

    # The remainder of the code runs the decoding loop and computes accuracies.
    if batch_has_lang_information:
        tokens_enc, _, _, labels, enc_mask, _, langs = self.process_global_batch(batch)
    else:
        tokens_enc, _, _, labels, enc_mask, _ = self.process_global_batch(batch)

    predicted_token_ids, _ = self.decode(tokens_enc=tokens_enc, enc_mask=enc_mask, num_tokens_to_generate=30)

    preds_text, labels_text = self.preds_and_labels_to_text(predicted_token_ids, labels)

    if not batch_has_lang_information:
        if (
            mode == 'validation'
            and hasattr(self.cfg.data.validation_ds, "names")
            and isinstance(self.cfg.data.validation_ds.names, ListConfig)
        ):
            categories = [self.cfg.data.validation_ds.names[dataloader_idx]] * len(preds_text)
        elif (
            mode == 'test'
            and hasattr(self.cfg.data.test_ds, "names")
            and isinstance(self.cfg.data.test_ds.names, ListConfig)
        ):
            categories = [self.cfg.data.test_ds.names[dataloader_idx]] * len(preds_text)
        else:
            categories = [None] * len(preds_text)
    else:
        categories = langs

    metric = self.val_metric if mode == 'validation' else self.test_metric
    assert len(categories) == len(preds_text) == len(labels_text)
    for _, (pred, label, category) in enumerate(zip(preds_text, labels_text, categories)):
        _ = metric(pred, label, category)

    return {'loss': loss, 'preds': preds_text, 'labels': labels_text, 'categories': categories}
def sample_sequence_batch(
    model,
    context_tokens,
    context_lengths,
    task_ids,
    attention_mask,
    position_ids,
    tokens_to_generate,
    all_probs=False,
    type_ids=None,
    temperature=None,
    extra={},
):
    # Importing here to avoid circular import errors
    from nemo.collections.nlp.models.language_modeling import MegatronGPTPromptLearningModel

    app_state = AppState()
    micro_batch_size = context_tokens.shape[0]
    _reconfigure_microbatch_calculator(
        rank=app_state.global_rank,
        rampup_batch_size=None,
        global_batch_size=micro_batch_size,
        micro_batch_size=micro_batch_size,
        data_parallel_size=1,
    )
    tokenizer = model.tokenizer
    model.eval()
    with torch.no_grad():
        context_length = context_lengths.min().item()

        # added eos_id to support the function generate_samples_eval that passes
        # eos_id as an argument and needs termination when that id is found.
        eod_id = tokenizer.eos_id

        counter = 0

        batch_size = context_tokens.size(0)
        is_done = torch.zeros([batch_size]).byte().cuda()
        tokens = context_tokens
        output_logits = None
        all_generated_indices = None  # used to track all generated indices

        # Generate enough tokens for the longest sequence
        maxlen = tokens_to_generate + context_lengths.max().item()

        if maxlen > model.cfg.encoder_seq_length + 1:
            maxlen = model.cfg.encoder_seq_length + 1

        lengths = torch.ones([batch_size]).long().cuda() * maxlen

        while context_length < maxlen:
            # types2use = None
            if counter == 0:
                # Allocate memory for the entire context.
                set_inference_key_value_memory = True
                tokens2use = tokens[:, :context_length]
                positions2use = position_ids[:, :context_length]
                # not using type2use. uncomment it if it is used
                # if type_ids is not None:
                #     types2use = type_ids[:, :context_length]
            else:
                # Set this to false so the memory is not reallocated.
                set_inference_key_value_memory = False
                tokens2use = tokens[:, context_length - 1].view(batch_size, -1)
                positions2use = position_ids[:, context_length - 1].view(batch_size, -1)
                # not using type2use. uncomment it if it is used
                # if type_ids is not None:
                #     types2use = type_ids[:, context_length - 1].view(batch_size, -1)

            attention_mask_repeat = torch.concat([attention_mask for _ in range(micro_batch_size)])
            setkey_value_array = torch.tensor(
                [set_inference_key_value_memory] * micro_batch_size, device=torch.cuda.current_device()
            )
            len_array = torch.tensor([maxlen] * micro_batch_size, device=torch.cuda.current_device())

            # Only prompt learning models will have a prompt table, and require task ids
            if isinstance(model, MegatronGPTPromptLearningModel):
                batch = [tokens2use, attention_mask_repeat, positions2use, task_ids, setkey_value_array, len_array]
                tensor_shape = [tokens2use.shape[1], micro_batch_size, model.frozen_model.cfg.hidden_size]
            else:
                batch = [tokens2use, attention_mask_repeat, positions2use, setkey_value_array, len_array]
                tensor_shape = [tokens2use.shape[1], micro_batch_size, model.cfg.hidden_size]

            output = forward_step(model, batch, tensor_shape)

            if parallel_state.is_pipeline_last_stage():
                output = output[0]['logits'].float()
                output = tensor_parallel.gather_from_tensor_model_parallel_region(output)
                assert output is not None
                output = output.float()
                logits = output[:, -1].view(batch_size, -1).contiguous()

                # make sure it will generate at least min_length
                min_length = extra.get('min_tokens_to_generate', 0)
                if min_length > 0:
                    within_min_length = (context_length - context_lengths) < min_length
                    logits[within_min_length, eod_id] = -float('Inf')

                # make sure it won't sample outside the vocab_size range
                logits[:, tokenizer.vocab_size:] = -float('Inf')

                if extra.get('greedy', False):
                    prev = torch.argmax(logits, dim=-1).view(-1)
                else:
                    logits = logits.float()
                    logits /= temperature
                    # handle repetition penalty
                    logits = repetition_penalty(logits, extra.get('repetition_penalty', 1.2), all_generated_indices)
                    logits = top_k_logits(logits, top_k=extra.get('top_k', 0), top_p=extra.get('top_p', 0.9))
                    log_probs = F.softmax(logits, dim=-1)
                    prev = torch.multinomial(log_probs, num_samples=1).view(-1)
                started = context_lengths <= context_length

                # Clamp the predicted out of vocabulary tokens
                prev = torch.clamp(prev, max=tokenizer.vocab_size - 1)
                new_tokens = switch(tokens[:, context_length].view(-1), prev, started)

                # Replace sampled tokens w/ done token if EOD has already been sampled
                new_tokens = switch(new_tokens, eod_id, is_done)

                # Replace special soft prompt token ids with unk token ids
                if isinstance(model, MegatronGPTPromptLearningModel):
                    pseudo_token_ids_start = model.pseudo_token_ids_start
                    new_tokens[(new_tokens >= pseudo_token_ids_start)] = tokenizer.unk_id
                    tokens[:, :context_length][
                        (tokens[:, :context_length] >= pseudo_token_ids_start)
                    ] = tokenizer.unk_id

                # Insert either new predicted or next prompt token
                tokens[:, context_length] = new_tokens

                if output_logits is None:
                    output = F.log_softmax(output[:, :context_length, :], 2)
                    indices = torch.unsqueeze(tokens[:, 1:context_length + 1], 2)
                    output_logits = torch.gather(output, 2, indices).squeeze(2)
                    all_generated_indices = indices[:, :, 0]
                    if all_probs:
                        full_logits = output
                else:
                    output = F.log_softmax(output, 2)
                    indices = torch.unsqueeze(new_tokens, 1).unsqueeze(2)
                    new_output_logits = torch.gather(output, 2, indices).squeeze(2)

                    # TODO(rprenger): we're copying output_logits every time. Should pre-allocate.
                    output_logits = torch.cat([output_logits, new_output_logits], 1)
                    all_generated_indices = torch.cat([all_generated_indices, indices[:, :, 0]], 1)
                    if all_probs:
                        full_logits = torch.cat([full_logits, output], 1)

                src = parallel_state.get_pipeline_model_parallel_last_rank()
                group = parallel_state.get_embedding_group()
                torch.distributed.broadcast(new_tokens, src, group)

                done_token = (prev == eod_id).byte() & started.byte()
                just_finished = (done_token & ~is_done).bool()
                lengths[just_finished.view(-1)] = context_length
                is_done = is_done | done_token
                done = torch.all(is_done)

                src = parallel_state.get_pipeline_model_parallel_last_rank()
                group = parallel_state.get_pipeline_model_parallel_group()
                torch.distributed.broadcast(done, src, group)
                if all_probs:
                    yield tokens, lengths, output_logits, full_logits
                else:
                    yield tokens, lengths, output_logits, None

            else:
                if parallel_state.is_pipeline_first_stage():
                    src = parallel_state.get_pipeline_model_parallel_last_rank()
                    group = parallel_state.get_embedding_group()
                    new_tokens = torch.empty_like(tokens[:, context_length])
                    torch.distributed.broadcast(new_tokens, src, group)
                    tokens[:, context_length] = new_tokens
                    yield tokens, None, None, None
                else:
                    yield None, None, None, None

                done = torch.cuda.ByteTensor([0])
                src = parallel_state.get_pipeline_model_parallel_last_rank()
                group = parallel_state.get_pipeline_model_parallel_group()
                torch.distributed.broadcast(done, src, group)

            context_length += 1
            counter += 1
            if done:
                break
def forward_backward_func_template(
    args,
    name: str,
    forward_backward_func,
    pipeline_model_parallel_size: int,
    forward_only: bool,
    dtype: torch.dtype,
    grad_scaler: Optional[GradScaler],
    deallocate_pipeline_outputs: bool,
    data_parallel_size: int,
) -> None:
    print_separator(
        f"{name}, {dtype}, use grad_scaler: {grad_scaler is not None}, "
        f"deallocate_pipeline_outputs: {deallocate_pipeline_outputs}, "
        f"pipeline parallel size: {pipeline_model_parallel_size}, "
        f"data parallel size: {data_parallel_size}"
    )
    virtual_pipeline_model_parallel_size = 2 if name == "interleaving" else None
    if name == "no_pipelining":
        # note (mkozuki): `forward_backward_no_pipelining` is **NOT** compatible with
        # pipeline_model_parallel_size>1. So use pipeline_model_parallel_size as
        # tensor_model_parallel_size and set pipeline_model_parallel_size to 1.
        parallel_state.initialize_model_parallel(1, 1, None)
        _reconfigure_microbatch_calculator(
            args.rank,
            args.rampup_batch_size,
            args.global_batch_size,
            args.micro_batch_size,
            parallel_state.get_data_parallel_world_size(),
        )
    else:
        # NOTE (mkozuki): `virtual_pipeline_model_parallel_size` is necessary to enable interleaving scheduling
        # In megatron, `args.virtual_pipeline_model_parallel_size` is computed in megatron/arguments.py and
        # used ubiquitously but this test uses custom model so it's safe to abuse.
        parallel_state.initialize_model_parallel(
            data_parallel_size, pipeline_model_parallel_size, virtual_pipeline_model_parallel_size
        )
        _reconfigure_microbatch_calculator(
            args.rank,
            args.rampup_batch_size,
            args.global_batch_size,
            args.micro_batch_size,
            parallel_state.get_data_parallel_world_size(),
        )
        if virtual_pipeline_model_parallel_size is not None:
            # Check the experimental warning message
            get_forward_backward_func(virtual_pipeline_model_parallel_size, pipeline_model_parallel_size)

    pipeline_model_parallel_size = parallel_state.get_pipeline_model_parallel_world_size()

    model = build_model(
        model_provider_func,
        wrap_with_ddp=True,
        virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size,
        hidden_size=hidden_size,
    )
    assert isinstance(model, list)
    assert len(model) == (
        1 if virtual_pipeline_model_parallel_size is None else virtual_pipeline_model_parallel_size
    )
    _param_groups = _get_params_for_weight_decay_optimization(model)
    torch.optim.Adam(_param_groups, lr=1e-4)

    tensor_shape = [batch_size // parallel_state.get_data_parallel_world_size(), hidden_size, hidden_size]
    batch = (torch.randn(tensor_shape).cuda(),)
    tensor_shape[0] = micro_batch_size

    update_num_microbatches(0)
    forward_backward_func(
        fwd_step_func,
        batch,
        model,
        forward_only=forward_only,
        tensor_shape=tensor_shape,
        dtype=dtype,
        grad_scaler=grad_scaler,
        deallocate_pipeline_outputs=deallocate_pipeline_outputs,
    )

    if not forward_only:
        for m in model:
            for p in m.parameters():
                if p.grad is None:
                    raise RuntimeError("grad not found")

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(TEST_SUCCESS_MESSAGE)
def eval_step(self, batch, batch_idx, dataloader_idx):
    # Need to squeeze dim 0 for tarred datasets since things are pre-batched and we ask the dataloader for batch size 1.
    batch = [[x.squeeze(dim=0) if x.ndim == 3 else x for x in microbatch] for microbatch in batch]
    batch = self.process_global_batch_for_tarred_datasets(batch)
    app_state = AppState()
    _reconfigure_microbatch_calculator(
        rank=app_state.global_rank,
        rampup_batch_size=None,
        global_batch_size=batch['text_enc'].size(0) * parallel_state.get_data_parallel_world_size(),
        micro_batch_size=batch['text_enc'].size(0),
        data_parallel_size=parallel_state.get_data_parallel_world_size(),
    )
    # This returns the averaged loss across data-parallel groups.
    reduced_loss = super().validation_step(batch, batch_idx)

    tokens_enc, labels, enc_mask = batch['text_enc'], batch['labels'], batch['enc_mask']

    predicted_tokens_ids, _ = self.decode(
        tokens_enc,
        enc_mask,
        tokens_enc.size(1)
        + self._cfg.max_generation_delta,  # Generate up to src-length + max generation delta. TODO: Implement better stopping when everything hits <EOS>.
        tokenizer=self.decoder_tokenizer,
    )

    # Post-process the translations and inputs to log.
    # Convert ids to lists.
    preds = predicted_tokens_ids.cpu().numpy().tolist()
    labels = labels.cpu().numpy().tolist()
    encoder_inputs = tokens_enc.cpu().numpy().tolist()

    # Filter out the special tokens and de-tokenize.
    inputs = []
    translations = []
    ground_truths = []
    for _, (pred, label, input) in enumerate(zip(preds, labels, encoder_inputs)):
        if self.decoder_tokenizer.eos_id in pred:
            idx = pred.index(self.decoder_tokenizer.eos_id)
            pred = pred[:idx]

        # Legacy sentencepiece detokenization still preserves special tokens which messes up exact string match.
        if hasattr(self.decoder_tokenizer, 'special_token_to_id'):
            pred = [id for id in pred if id not in self.decoder_tokenizer.special_token_to_id.values()]
            label = [id for id in label if id not in self.decoder_tokenizer.special_token_to_id.values()]
        if hasattr(self.encoder_tokenizer, 'special_token_to_id'):
            input = [id for id in input if id not in self.encoder_tokenizer.special_token_to_id.values()]

        pred = self.decoder_tokenizer.ids_to_text(pred)
        label = self.decoder_tokenizer.ids_to_text(label)
        input = self.encoder_tokenizer.ids_to_text(input)
        translations.append(pred)
        ground_truths.append(label)
        inputs.append(input)

    if self.multilingual:
        self.source_processor = self.source_processor_list[dataloader_idx]
        self.target_processor = self.target_processor_list[dataloader_idx]

    # De-tokenize inputs, translations and ground truths.
    if self.target_processor is not None:
        ground_truths = [self.target_processor.detokenize(tgt.split(' ')) for tgt in ground_truths]
        translations = [self.target_processor.detokenize(translation.split(' ')) for translation in translations]

    if self.source_processor is not None:
        inputs = [self.source_processor.detokenize(src.split(' ')) for src in inputs]

    return {
        'inputs': inputs,
        'translations': translations,
        'ground_truths': ground_truths,
        'loss': reduced_loss,
    }
def decode(self, tokens_enc, enc_mask, num_tokens_to_generate, encoder_input=None, tokenizer=None):
    # Check whether the DDP is initialized. This is needed when running inference outside of the training loop.
    if parallel_state.is_unitialized():

        def dummy():
            return

        if self.trainer.strategy.launcher is not None:
            self.trainer.strategy.launcher.launch(dummy, trainer=self.trainer)
        self.trainer.strategy.setup_environment()

        # Reconfigure microbatch sizes here because on model restore, this will contain the micro/global batch configuration used while training.
        _reconfigure_microbatch_calculator(
            rank=0,  # This doesn't matter since it is only used for logging
            rampup_batch_size=None,
            global_batch_size=1,
            micro_batch_size=1,  # Make sure that there is no "grad acc" while decoding.
            data_parallel_size=1,  # We check above to make sure that data parallel size is always 1 at inference.
        )

    # If classes that inherit from this class are using a different tokenizer, use the one they pass in;
    # otherwise default to this model's tokenizer.
    tokenizer = self.tokenizer if tokenizer is None else tokenizer
    app_state = AppState()
    global_batch_per_gpu = tokens_enc.size(0)

    num_micro_batches_before_decode = get_num_microbatches()
    # Reconfigure microbatch calculator here to set the number of microbatches to 1 while decoding,
    # since it's not clear how to decode with "grad acc".
    # TODO: reconfigure back to how things were before decode?
    _reconfigure_microbatch_calculator(
        rank=app_state.global_rank,
        rampup_batch_size=None,
        global_batch_size=global_batch_per_gpu * parallel_state.get_data_parallel_world_size(),
        micro_batch_size=global_batch_per_gpu,  # Make sure that there is no "grad acc" while decoding.
        data_parallel_size=parallel_state.get_data_parallel_world_size(),
    )
    predicted_tokens_dec = (
        torch.LongTensor([tokenizer.bos_id] * global_batch_per_gpu).unsqueeze(1).to(tokens_enc.device)
    )
    encoder_seq_length = tokens_enc.size(1)
    tensor_shape = [encoder_seq_length, global_batch_per_gpu, self.cfg.hidden_size]
    assert predicted_tokens_dec.size(0) == global_batch_per_gpu

    for i in range(num_tokens_to_generate):
        # No microbatches in decoding. Just the global batch.
        decoder_seq_length = predicted_tokens_dec.size(1)
        dec_mask = predicted_tokens_dec != tokenizer.pad_id

        if encoder_input is not None:
            batch_for_pipeline = [tokens_enc, predicted_tokens_dec, enc_mask, dec_mask, encoder_input]
        else:
            batch_for_pipeline = [tokens_enc, predicted_tokens_dec, enc_mask, dec_mask]
        if self.cfg.get('pipeline_model_parallel_size', 1) > 1:
            output_tensor = forward_backward_pipelining_without_interleaving(
                forward_step_func=self.get_forward_output_only_func(),
                batch=batch_for_pipeline,
                model=self.enc_dec_model,
                forward_only=True,
                tensor_shape=tensor_shape,
                decoder_sequence_length=decoder_seq_length,
                dtype=self.autocast_dtype,
            )
        else:
            output_tensor = forward_backward_no_pipelining(
                forward_step_func=self.get_forward_output_only_func(),
                batch=batch_for_pipeline,
                model=self.enc_dec_model,
                forward_only=True,
                tensor_shape=tensor_shape,
                decoder_sequence_length=decoder_seq_length,
                dtype=self.autocast_dtype,
            )

        # get output tensor
        if parallel_state.is_pipeline_last_stage():
            output_tensor = output_tensor[0]['logits']
            output_tensor = tensor_parallel.gather_from_tensor_model_parallel_region(output_tensor)
            log_probs, token_ids = torch.max(torch.nn.functional.log_softmax(output_tensor, dim=-1), dim=-1)
            predicted_tokens_dec = torch.cat(
                [predicted_tokens_dec.to(token_ids.device), token_ids[:, -1].unsqueeze(1)], dim=1
            )
        else:
            log_probs = torch.zeros(
                (predicted_tokens_dec.shape[0], predicted_tokens_dec.shape[1]), dtype=self.autocast_dtype
            ).cuda()
            predicted_tokens_dec = torch.zeros(
                (predicted_tokens_dec.shape[0], predicted_tokens_dec.shape[1] + 1),
                dtype=predicted_tokens_dec.dtype,
            ).cuda()

        if self.cfg.get('pipeline_model_parallel_size', 1) > 1:
            # Broadcast from the last pipeline stage to all other model-parallel ranks.
            torch.distributed.broadcast(
                predicted_tokens_dec,
                parallel_state.get_pipeline_model_parallel_last_rank(),
                group=parallel_state.get_model_parallel_group(),
            )
            torch.distributed.broadcast(
                log_probs,
                parallel_state.get_pipeline_model_parallel_last_rank(),
                group=parallel_state.get_model_parallel_group(),
            )

    # Reset microbatch calculator to what it was before decoding.
    _reconfigure_microbatch_calculator(
        rank=app_state.global_rank,
        rampup_batch_size=None,
        global_batch_size=global_batch_per_gpu * parallel_state.get_data_parallel_world_size(),
        micro_batch_size=global_batch_per_gpu // num_micro_batches_before_decode,
        data_parallel_size=parallel_state.get_data_parallel_world_size(),
    )

    return predicted_tokens_dec, log_probs
def run_interleaved_with_dynamic_batch_size(
    pipeline_model_parallel_size: int, forward_only: bool, BatchSamplerCls,
) -> None:
    args = global_vars.get_args()
    _reconfigure_microbatch_calculator(
        args.rank,
        args.rampup_batch_size,
        args.global_batch_size,
        args.micro_batch_size,
        1,  # args.data_parallel_size,
    )
    virtual_pipeline_model_parallel_size = 2
    # NOTE (mkozuki): `virtual_pipeline_model_parallel_size` is a requisite for the interleaving scheduling
    # In megatron, `args.virtual_pipeline_model_parallel_size` is computed in megatron/arguments.py and
    # used ubiquitously but this test uses custom model so it's safe to abuse.
    parallel_state.initialize_model_parallel(
        1, pipeline_model_parallel_size, virtual_pipeline_model_parallel_size
    )
    pipeline_model_parallel_size = parallel_state.get_pipeline_model_parallel_world_size()

    print_separator(f"BatchSamplerCls: {BatchSamplerCls.__name__}, forward_only: {forward_only}")

    model = build_model(
        model_provider_func,
        wrap_with_ddp=True,
        virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size,
        hidden_size=HIDDEN_SIZE,
    )
    assert isinstance(model, list)
    assert len(model) == virtual_pipeline_model_parallel_size
    optimizer = torch.optim.Adam(_get_params_for_weight_decay_optimization(model))

    initial_local_minibatch_size = get_num_microbatches() * micro_batch_size
    dataset = Dataset(NUM_SAMPLES)
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_sampler=BatchSamplerCls(
            NUM_SAMPLES,
            0,
            initial_local_minibatch_size,
            parallel_state.get_data_parallel_rank(),
            parallel_state.get_data_parallel_world_size(),
        ),
    )
    data_iter = iter(data_loader)

    def get_num_samples(batch):
        if isinstance(batch, torch.Tensor):
            return len(batch)
        assert isinstance(batch, (list, tuple))
        return [get_num_samples(b) for b in batch]

    tensor_shape = [micro_batch_size, HIDDEN_SIZE, HIDDEN_SIZE]
    consumed_samples = 0
    for i in range(NUM_ITERATIONS):
        update_num_microbatches(consumed_samples, consistency_check=False)
        local_batch_size = get_num_microbatches() * micro_batch_size
        data_iter._index_sampler.local_minibatch_size = local_batch_size
        local_mini_batch = next(data_iter)

        _logger.info(
            f"iter: {i} / {NUM_ITERATIONS} "
            f"local batchsize: {get_num_samples(local_mini_batch)} "
            f"consumed_samples: {consumed_samples} / {NUM_SAMPLES}"
        )
        _forward_backward_pipelining_with_interleaving(
            fwd_step_func,
            local_mini_batch,
            model,
            forward_only=forward_only,
            tensor_shape=tensor_shape,
        )

        consumed_samples += (
            parallel_state.get_data_parallel_world_size() * get_num_microbatches() * micro_batch_size
        )

        if not forward_only:
            for m in model:
                for p in m.parameters():
                    if p.grad is None:
                        raise RuntimeError("grad not found")
            else:
                optimizer.zero_grad(set_to_none=True)

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(TEST_SUCCESS_MESSAGE)