def count_zeros_fp32(parameters):
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]

    # Filter parameters based on:
    #   - grad should not be None
    #   - parameter should not be shared
    #   - should not be a replica due to tensor model parallelism
    total_num_zeros = 0.0
    for param in parameters:
        grad_not_none = param.grad is not None
        is_not_shared = param_is_not_shared(param)
        is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
        if grad_not_none and is_not_shared and is_not_tp_duplicate:
            grad = param.grad.detach()
            num_zeros = grad.numel() - torch.count_nonzero(grad)
            total_num_zeros = num_zeros + total_num_zeros

    # Sum across all model-parallel GPUs.
    torch.distributed.all_reduce(
        total_num_zeros,
        op=torch.distributed.ReduceOp.SUM,
        group=parallel_state.get_model_parallel_group())
    total_num_zeros = total_num_zeros.item()

    return total_num_zeros
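# Hedged usage sketch (not part of the original module): log the number of
# zero-valued gradient entries after a backward pass, e.g. as a sparsity
# diagnostic. Assumes torch.distributed and the model-parallel groups used by
# parallel_state are already initialized; `model` and `loss` are placeholders.
def log_zero_grads_example(model, loss):
    loss.backward()
    num_zeros = count_zeros_fp32(list(model.parameters()))
    print(f'zero-valued grad entries across model-parallel ranks: {num_zeros}')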
def calc_params_l2_norm(model: torch.nn.Module, bf16: bool):
    """Calculate the l2 norm of parameters."""
    if not isinstance(model, list):
        model = [model]
    # Remove duplicate params.
    params_data = []
    for model_ in model:
        for param in model_.parameters():
            is_not_shared = param_is_not_shared(param)
            is_not_tp_duplicate = parallel_state.param_is_not_tensor_parallel_duplicate(param)
            if is_not_shared and is_not_tp_duplicate:
                if bf16:
                    params_data.append(param.data.float())
                else:
                    params_data.append(param.data)

    # Calculate norm.
    dummy_overflow_buf = torch.cuda.IntTensor([0])
    norm, _ = multi_tensor_applier(
        amp_C.multi_tensor_l2norm,
        dummy_overflow_buf,
        [params_data],
        False,  # no per-parameter norm
    )
    norm_2 = norm * norm

    # Sum across all model-parallel GPUs.
    torch.distributed.all_reduce(
        norm_2,
        op=torch.distributed.ReduceOp.SUM,
        group=parallel_state.get_model_parallel_group())

    return norm_2.item() ** 0.5
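# Hedged usage sketch (not part of the original module): log the global parameter
# L2 norm each step, e.g. for comparing runs. Assumes apex's amp_C /
# multi_tensor_applier are available, the model-parallel groups are initialized,
# and `bf16` matches the training precision; `model` and `step` are placeholders.
def log_param_norm_example(model, step, bf16=True):
    param_norm = calc_params_l2_norm(model, bf16=bf16)
    print(f'step {step}: params L2 norm = {param_norm:.3f}')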
def restore_weights(self, restore_path: str):
    """Restores module/model's weights.
       For model parallel checkpoints the directory structure should be
       restore_path/mp_rank_0X/model_optim_rng.pt

    Args:
        restore_path (str): restore_path should be a file, or a directory if using model parallel
    """
    self._restore_path = restore_path

    if os.path.isfile(restore_path):
        self._load_checkpoint(restore_path)
    elif os.path.isdir(restore_path):
        # Need model parallel groups to restore model parallel checkpoints.
        if model_parallel_is_initialized():
            model_parallel_rank = torch.distributed.get_rank(group=get_model_parallel_group())
            mp_restore_path = f'{restore_path}/mp_rank_{model_parallel_rank:02d}/model_optim_rng.pt'
            self._load_checkpoint(mp_restore_path)
        else:
            logging.info('torch.distributed not initialized yet. Will not restore model parallel checkpoint')
    else:
        logging.error(f'restore_path: {restore_path} must be a file or directory.')
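# Hedged usage sketch with hypothetical paths: a model-parallel checkpoint
# directory is expected to look like
#   /ckpts/my_model/mp_rank_00/model_optim_rng.pt
#   /ckpts/my_model/mp_rank_01/model_optim_rng.pt
# and the model-parallel groups must be initialized before restoring from it.
def restore_example(module, use_model_parallel_ckpt=True):
    if use_model_parallel_ckpt:
        module.restore_weights('/ckpts/my_model')             # directory form, one file per mp rank
    else:
        module.restore_weights('/ckpts/my_model/weights.pt')  # single-file checkpoint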
def _maybe_opt_step(self, optimizer, optimizer_state, *args, **kwargs):
    retval = None
    found_inf = torch.cuda.FloatTensor(
        [sum(v.item() for v in optimizer_state["found_inf_per_device"].values())]
    )

    # Update across all model parallel instances.
    torch.distributed.all_reduce(
        found_inf, op=torch.distributed.ReduceOp.MAX, group=parallel_state.get_model_parallel_group()
    )

    if found_inf.item() == 0:
        retval = optimizer.step(*args, **kwargs)
    return retval
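# Hedged usage sketch: with the model-parallel-aware _maybe_opt_step above, the
# standard AMP training loop skips optimizer.step() on every rank whenever any
# model-parallel rank saw an inf/nan gradient. `scaler` is assumed to be a
# GradScaler subclass carrying this method; the loop itself is plain torch.cuda.amp.
import torch

def amp_step_example(scaler, optimizer, model, batch):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = model(batch)
    scaler.scale(loss).backward()
    scaler.step(optimizer)  # GradScaler.step() calls _maybe_opt_step() internally
    scaler.update()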
def decode(self, tokens_enc, enc_mask, num_tokens_to_generate, encoder_input=None):
    app_state = AppState()
    global_batch_per_gpu = tokens_enc.size(0)

    num_micro_batches_before_decode = get_num_microbatches()
    # Reconfigure the microbatch calculator to force the number of microbatches to 1
    # while decoding, since it is not clear how to decode with "grad acc".
    # TODO: reconfigure back to how things were before decode?
    _reconfigure_microbatch_calculator(
        rank=app_state.global_rank,
        rampup_batch_size=None,
        global_batch_size=global_batch_per_gpu * parallel_state.get_data_parallel_world_size(),
        micro_batch_size=global_batch_per_gpu,  # Make sure that there is no "grad acc" while decoding.
        data_parallel_size=parallel_state.get_data_parallel_world_size(),
    )
    predicted_tokens_dec = (
        torch.LongTensor([self.tokenizer.bos_id] * global_batch_per_gpu).unsqueeze(1).to(tokens_enc.device)
    )
    encoder_seq_length = tokens_enc.size(1)
    tensor_shape = [encoder_seq_length, global_batch_per_gpu, self.cfg.hidden_size]

    assert predicted_tokens_dec.size(0) == global_batch_per_gpu

    for i in range(num_tokens_to_generate):
        # No microbatches in decoding. Just the global batch.
        decoder_seq_length = predicted_tokens_dec.size(1)
        dec_mask = predicted_tokens_dec != self.tokenizer.pad_id

        if encoder_input is not None:
            batch_for_pipeline = [tokens_enc, predicted_tokens_dec, enc_mask, dec_mask, encoder_input]
        else:
            batch_for_pipeline = [tokens_enc, predicted_tokens_dec, enc_mask, dec_mask]
        if self.cfg.get('pipeline_model_parallel_size', 1) > 1:
            output_tensor = forward_backward_pipelining_without_interleaving(
                forward_step_func=self.get_forward_output_only_func(),
                batch=batch_for_pipeline,
                model=self.enc_dec_model,
                forward_only=True,
                tensor_shape=tensor_shape,
                decoder_sequence_length=decoder_seq_length,
                dtype=self.autocast_dtype,
            )
        else:
            output_tensor = forward_backward_no_pipelining(
                forward_step_func=self.get_forward_output_only_func(),
                batch=batch_for_pipeline,
                model=self.enc_dec_model,
                forward_only=True,
                tensor_shape=tensor_shape,
                decoder_sequence_length=decoder_seq_length,
                dtype=self.autocast_dtype,
            )

        # Get the output tensor on the last pipeline stage and append the greedy token.
        if parallel_state.is_pipeline_last_stage():
            output_tensor = output_tensor[0]['logits']
            output_tensor = tensor_parallel.gather_from_tensor_model_parallel_region(output_tensor)
            log_probs, token_ids = torch.max(torch.nn.functional.log_softmax(output_tensor, dim=-1), dim=-1)
            predicted_tokens_dec = torch.cat(
                [predicted_tokens_dec.to(token_ids.device), token_ids[:, -1].unsqueeze(1)], dim=1
            )
        else:
            log_probs = torch.zeros(
                (predicted_tokens_dec.shape[0], predicted_tokens_dec.shape[1]), dtype=self.autocast_dtype
            ).cuda()
            predicted_tokens_dec = torch.zeros(
                (predicted_tokens_dec.shape[0], predicted_tokens_dec.shape[1] + 1),
                dtype=predicted_tokens_dec.dtype,
            ).cuda()

        if self.cfg.get('pipeline_model_parallel_size', 1) > 1:
            # Broadcast from the last pipeline stage to all other model-parallel ranks.
            torch.distributed.broadcast(
                predicted_tokens_dec,
                parallel_state.get_pipeline_model_parallel_last_rank(),
                group=parallel_state.get_model_parallel_group(),
            )
            torch.distributed.broadcast(
                log_probs,
                parallel_state.get_pipeline_model_parallel_last_rank(),
                group=parallel_state.get_model_parallel_group(),
            )

    # Reset the microbatch calculator to what it was before decoding.
    _reconfigure_microbatch_calculator(
        rank=app_state.global_rank,
        rampup_batch_size=None,
        global_batch_size=global_batch_per_gpu * parallel_state.get_data_parallel_world_size(),
        micro_batch_size=global_batch_per_gpu // num_micro_batches_before_decode,
        data_parallel_size=parallel_state.get_data_parallel_world_size(),
    )
    return predicted_tokens_dec, log_probs
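# Hedged usage sketch (tensor names are illustrative): greedy generation with the
# decode() above. `tokens_enc` holds encoder token ids of shape [batch, src_len]
# and `enc_mask` is the matching attention mask, both already on the model's device.
def greedy_generate_example(model, tokens_enc, enc_mask, num_tokens_to_generate=32):
    predicted_tokens_dec, log_probs = model.decode(tokens_enc, enc_mask, num_tokens_to_generate)
    # predicted_tokens_dec starts with the BOS id, so drop the first column
    # before detokenizing the generated ids.
    return predicted_tokens_dec[:, 1:], log_probs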
def update(self, new_scale=None):
    """
    Updates the scale factor.

    If any optimizer steps were skipped, the scale is multiplied by
    ``backoff_factor`` to reduce it. If ``growth_interval`` unskipped
    iterations occurred consecutively, the scale is multiplied by
    ``growth_factor`` to increase it.

    Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is
    not used directly; it is used to fill GradScaler's internal scale tensor.
    So if ``new_scale`` was a tensor, later in-place changes to that tensor
    will not further affect the scale GradScaler uses internally.)

    Args:
        new_scale (float or :class:`torch.cuda.FloatTensor`, optional, default=None):  New scale factor.

    .. warning::
        :meth:`update` should only be called at the end of the iteration, after
        ``scaler.step(optimizer)`` has been invoked for all optimizers used this
        iteration.
    """
    if not self._enabled:
        return

    _scale, _growth_tracker = self._check_scale_growth_tracker("update")

    if new_scale is not None:
        # Accept a new user-defined scale.
        if isinstance(new_scale, float):
            self._scale.fill_(new_scale)  # type: ignore[union-attr]
        else:
            reason = "new_scale should be a float or a 1-element torch.cuda.FloatTensor with requires_grad=False."
            assert isinstance(new_scale, torch.cuda.FloatTensor), reason  # type: ignore[attr-defined]
            assert new_scale.numel() == 1, reason
            assert new_scale.requires_grad is False, reason
            self._scale.copy_(new_scale)  # type: ignore[union-attr]
    else:
        # Consume shared inf/nan data collected from optimizers to update the scale.
        # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous.
        found_infs = [
            found_inf.to(device=_scale.device, non_blocking=True)
            for state in self._per_optimizer_states.values()
            for found_inf in state["found_inf_per_device"].values()
        ]

        assert len(found_infs) > 0, "No inf checks were recorded prior to update."

        found_inf_combined = found_infs[0]

        # Update across all model parallel instances.
        torch.distributed.all_reduce(
            found_inf_combined,
            op=torch.distributed.ReduceOp.MAX,
            group=parallel_state.get_model_parallel_group())

        if len(found_infs) > 1:
            for i in range(1, len(found_infs)):
                found_inf = found_infs[i]
                # Update across all model parallel instances.
                torch.distributed.all_reduce(
                    found_inf,
                    op=torch.distributed.ReduceOp.MAX,
                    group=parallel_state.get_model_parallel_group())
                found_inf_combined += found_inf

        torch._amp_update_scale_(
            _scale,
            _growth_tracker,
            found_inf_combined,
            self._growth_factor,
            self._backoff_factor,
            self._growth_interval,
        )

    # To prepare for next iteration, clear the data collected from optimizers this iteration.
    self._per_optimizer_states = defaultdict(torch.cuda.amp.grad_scaler._refresh_per_optimizer_state)
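# Hedged illustration (not part of the original module) of the update rule this
# method applies via torch._amp_update_scale_. The constants are the GradScaler
# defaults (growth_factor=2.0, backoff_factor=0.5, growth_interval=2000): any
# overflow multiplies the scale by backoff_factor and resets the growth tracker;
# otherwise the tracker counts up and the scale grows every growth_interval clean steps.
def next_scale(scale, growth_tracker, found_inf,
               growth_factor=2.0, backoff_factor=0.5, growth_interval=2000):
    if found_inf:
        return scale * backoff_factor, 0
    growth_tracker += 1
    if growth_tracker == growth_interval:
        return scale * growth_factor, 0
    return scale, growth_tracker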
def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
    """Clips gradient norm of an iterable of parameters whose gradients
       are in fp32.

    This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ with added
    functionality to handle model parallel parameters. Note that the gradients
    are modified in place.

    Arguments:
        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
            single Tensor that will have gradients normalized
        max_norm (float or int): max norm of the gradients
        norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
            infinity norm.

    Returns:
        Total norm of the parameters (viewed as a single vector).
    """

    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]

    # Filter parameters based on:
    #   - grad should not be None
    #   - parameter should not be shared
    #   - should not be a replica due to tensor model parallelism
    grads = []
    grads_for_norm = []
    for param in parameters:
        grad_not_none = param.grad is not None
        is_not_shared = param_is_not_shared(param)
        is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
        if grad_not_none:
            grad = param.grad.detach()
            # Make sure the grads are in fp32.
            assert isinstance(param.grad, torch.cuda.FloatTensor)
            grads.append(grad)
            if is_not_shared and is_not_tp_duplicate:
                grads_for_norm.append(grad)

    # Norm parameters.
    max_norm = float(max_norm)
    norm_type = float(norm_type)
    total_norm = 0.0

    # Calculate norm.
    if norm_type == inf:
        total_norm = max(grad.abs().max() for grad in grads_for_norm)
        total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
        # Take max across all model-parallel GPUs.
        torch.distributed.all_reduce(
            total_norm_cuda,
            op=torch.distributed.ReduceOp.MAX,
            group=parallel_state.get_model_parallel_group())
        total_norm = total_norm_cuda[0].item()
    else:
        if norm_type == 2.0:
            dummy_overflow_buf = torch.cuda.IntTensor([0])
            # Use apex's multi-tensor applier for efficiency reasons.
            # Multi-tensor applier takes a function and a list of lists
            # and performs the operation on that list all in one kernel.
            grad_norm, _ = multi_tensor_applier(
                amp_C.multi_tensor_l2norm,
                dummy_overflow_buf,
                [grads_for_norm],
                False,  # no per-parameter norm
            )
            # Since we will be summing across data parallel groups,
            # we need the pow(norm-type).
            total_norm = grad_norm ** norm_type
        else:
            for grad in grads_for_norm:
                grad_norm = torch.norm(grad, norm_type)
                total_norm += grad_norm ** norm_type

        # Sum across all model-parallel GPUs.
        torch.distributed.all_reduce(
            total_norm,
            op=torch.distributed.ReduceOp.SUM,
            group=parallel_state.get_model_parallel_group())
        total_norm = total_norm.item() ** (1.0 / norm_type)

    # Scale.
    clip_coeff = max_norm / (total_norm + 1.0e-6)
    if clip_coeff < 1.0:
        dummy_overflow_buf = torch.cuda.IntTensor([0])
        multi_tensor_applier(amp_C.multi_tensor_scale,
                             dummy_overflow_buf,
                             [grads, grads],
                             clip_coeff)

    return total_norm
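# Hedged usage sketch (not part of the original module): clip the model-parallel
# aware global gradient norm right before the optimizer step. Assumes the
# gradients being clipped are fp32 (e.g. the fp32 master grads of a
# mixed-precision optimizer); `params` and `optimizer` are placeholders.
def clip_and_step_example(params, optimizer, max_norm=1.0):
    grad_norm = clip_grad_norm_fp32(params, max_norm, norm_type=2)
    optimizer.step()
    return grad_norm  # useful to log alongside the loss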
def decode(self, tokens_enc, enc_mask, num_tokens_to_generate, encoder_input=None, tokenizer=None):
    # Check whether DDP is initialized. This is needed when running inference outside of the training loop.
    if parallel_state.is_unitialized():

        def dummy():
            return

        if self.trainer.strategy.launcher is not None:
            self.trainer.strategy.launcher.launch(dummy, trainer=self.trainer)
        self.trainer.strategy.setup_environment()

        # Reconfigure microbatch sizes here because on model restore, this will contain
        # the micro/global batch configuration used while training.
        _reconfigure_microbatch_calculator(
            rank=0,  # This doesn't matter since it is only used for logging
            rampup_batch_size=None,
            global_batch_size=1,
            micro_batch_size=1,  # Make sure that there is no "grad acc" while decoding.
            data_parallel_size=1,  # We check above to make sure that data-parallel size is always 1 at inference.
        )

    # Classes that inherit from this class may use a different tokenizer.
    tokenizer = self.tokenizer if tokenizer is None else tokenizer
    app_state = AppState()
    global_batch_per_gpu = tokens_enc.size(0)

    num_micro_batches_before_decode = get_num_microbatches()
    # Reconfigure the microbatch calculator to force the number of microbatches to 1
    # while decoding, since it is not clear how to decode with "grad acc".
    # TODO: reconfigure back to how things were before decode?
    _reconfigure_microbatch_calculator(
        rank=app_state.global_rank,
        rampup_batch_size=None,
        global_batch_size=global_batch_per_gpu * parallel_state.get_data_parallel_world_size(),
        micro_batch_size=global_batch_per_gpu,  # Make sure that there is no "grad acc" while decoding.
        data_parallel_size=parallel_state.get_data_parallel_world_size(),
    )
    predicted_tokens_dec = (
        torch.LongTensor([tokenizer.bos_id] * global_batch_per_gpu).unsqueeze(1).to(tokens_enc.device)
    )
    encoder_seq_length = tokens_enc.size(1)
    tensor_shape = [encoder_seq_length, global_batch_per_gpu, self.cfg.hidden_size]

    assert predicted_tokens_dec.size(0) == global_batch_per_gpu

    for i in range(num_tokens_to_generate):
        # No microbatches in decoding. Just the global batch.
        decoder_seq_length = predicted_tokens_dec.size(1)
        dec_mask = predicted_tokens_dec != tokenizer.pad_id

        if encoder_input is not None:
            batch_for_pipeline = [tokens_enc, predicted_tokens_dec, enc_mask, dec_mask, encoder_input]
        else:
            batch_for_pipeline = [tokens_enc, predicted_tokens_dec, enc_mask, dec_mask]
        if self.cfg.get('pipeline_model_parallel_size', 1) > 1:
            output_tensor = forward_backward_pipelining_without_interleaving(
                forward_step_func=self.get_forward_output_only_func(),
                batch=batch_for_pipeline,
                model=self.enc_dec_model,
                forward_only=True,
                tensor_shape=tensor_shape,
                decoder_sequence_length=decoder_seq_length,
                dtype=self.autocast_dtype,
            )
        else:
            output_tensor = forward_backward_no_pipelining(
                forward_step_func=self.get_forward_output_only_func(),
                batch=batch_for_pipeline,
                model=self.enc_dec_model,
                forward_only=True,
                tensor_shape=tensor_shape,
                decoder_sequence_length=decoder_seq_length,
                dtype=self.autocast_dtype,
            )

        # Get the output tensor on the last pipeline stage and append the greedy token.
        if parallel_state.is_pipeline_last_stage():
            output_tensor = output_tensor[0]['logits']
            output_tensor = tensor_parallel.gather_from_tensor_model_parallel_region(output_tensor)
            log_probs, token_ids = torch.max(torch.nn.functional.log_softmax(output_tensor, dim=-1), dim=-1)
            predicted_tokens_dec = torch.cat(
                [predicted_tokens_dec.to(token_ids.device), token_ids[:, -1].unsqueeze(1)], dim=1
            )
        else:
            log_probs = torch.zeros(
                (predicted_tokens_dec.shape[0], predicted_tokens_dec.shape[1]), dtype=self.autocast_dtype
            ).cuda()
            predicted_tokens_dec = torch.zeros(
                (predicted_tokens_dec.shape[0], predicted_tokens_dec.shape[1] + 1),
                dtype=predicted_tokens_dec.dtype,
            ).cuda()

        if self.cfg.get('pipeline_model_parallel_size', 1) > 1:
            # Broadcast from the last pipeline stage to all other model-parallel ranks.
            torch.distributed.broadcast(
                predicted_tokens_dec,
                parallel_state.get_pipeline_model_parallel_last_rank(),
                group=parallel_state.get_model_parallel_group(),
            )
            torch.distributed.broadcast(
                log_probs,
                parallel_state.get_pipeline_model_parallel_last_rank(),
                group=parallel_state.get_model_parallel_group(),
            )

    # Reset the microbatch calculator to what it was before decoding.
    _reconfigure_microbatch_calculator(
        rank=app_state.global_rank,
        rampup_batch_size=None,
        global_batch_size=global_batch_per_gpu * parallel_state.get_data_parallel_world_size(),
        micro_batch_size=global_batch_per_gpu // num_micro_batches_before_decode,
        data_parallel_size=parallel_state.get_data_parallel_world_size(),
    )
    return predicted_tokens_dec, log_probs
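# Hedged usage sketch for standalone inference (outside a training loop), which is
# the case the parallel_state.is_unitialized() branch above handles by lazily
# setting up the distributed environment. `restored_model` and `custom_tokenizer`
# are placeholders; passing tokenizer=None falls back to the model's own tokenizer.
def standalone_decode_example(restored_model, tokens_enc, enc_mask, custom_tokenizer=None):
    return restored_model.decode(
        tokens_enc, enc_mask, num_tokens_to_generate=16, tokenizer=custom_tokenizer
    )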
def update(self, new_scale=None):
    """
    Updates the scale factor. Compared to the native grad scaler update, this version:
    1. Checks for inf across model-parallel ranks.
    2. Updates the hysteresis tracker.
    3. Applies hysteresis to the grad scale update.
    """
    if not self._enabled:
        return

    _scale, _growth_tracker = self._check_scale_growth_tracker("update")

    if new_scale is not None:
        # Accept a new user-defined scale.
        if isinstance(new_scale, float):
            self._scale.fill_(new_scale)  # type: ignore[union-attr]
        else:
            reason = "new_scale should be a float or a 1-element torch.cuda.FloatTensor with requires_grad=False."
            assert isinstance(new_scale, torch.cuda.FloatTensor), reason  # type: ignore[attr-defined]
            assert new_scale.numel() == 1, reason
            assert new_scale.requires_grad is False, reason
            self._scale.copy_(new_scale)  # type: ignore[union-attr]
    else:
        # Consume shared inf/nan data collected from optimizers to update the scale.
        # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous.
        found_infs = [
            found_inf.to(device=_scale.device, non_blocking=True)
            for state in self._per_optimizer_states.values()
            for found_inf in state["found_inf_per_device"].values()
        ]

        assert len(found_infs) > 0, "No inf checks were recorded prior to update."

        found_inf_combined = found_infs[0]

        # Update across all model parallel instances.
        torch.distributed.all_reduce(
            found_inf_combined,
            op=torch.distributed.ReduceOp.MAX,
            group=parallel_state.get_model_parallel_group())

        if len(found_infs) > 1:
            for i in range(1, len(found_infs)):
                found_inf = found_infs[i]
                # Update across all model parallel instances.
                torch.distributed.all_reduce(
                    found_inf,
                    op=torch.distributed.ReduceOp.MAX,
                    group=parallel_state.get_model_parallel_group())
                found_inf_combined += found_inf

        if found_inf_combined > 0:
            self._hysteresis_tracker -= 1
            if self._hysteresis_tracker <= 0:
                # When hysteresis is exhausted, follow the native grad scale update rule:
                # back off the scale and reset the growth tracker.
                torch._amp_update_scale_(
                    _scale,
                    _growth_tracker,
                    found_inf_combined,
                    self._growth_factor,
                    self._backoff_factor,
                    self._growth_interval,
                )
            else:
                # Only reset the growth tracker while hysteresis is larger than zero.
                _growth_tracker.fill_(0)
        else:
            # When no inf is found, follow the native grad scale update rule:
            # increment growth_tracker, update the scale when the growth tracker reaches
            # the interval, and reset the hysteresis tracker.
            torch._amp_update_scale_(
                _scale,
                _growth_tracker,
                found_inf_combined,
                self._growth_factor,
                self._backoff_factor,
                self._growth_interval,
            )
            self._hysteresis_tracker = self.hysteresis

    # To prepare for next iteration, clear the data collected from optimizers this iteration.
    self._per_optimizer_states = defaultdict(torch.cuda.amp.grad_scaler._refresh_per_optimizer_state)
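# Hedged illustration (not part of the original module) of the hysteresis
# behaviour above, with assumed constants: with hysteresis=2, a single overflowing
# step does NOT reduce the scale; only after `hysteresis` consecutive overflows
# does the scale back off, and one clean step refills the hysteresis budget.
def hysteresis_next_scale(scale, hysteresis_tracker, found_inf,
                          hysteresis=2, backoff_factor=0.5):
    if found_inf:
        hysteresis_tracker -= 1
        if hysteresis_tracker <= 0:
            return scale * backoff_factor, hysteresis_tracker
        return scale, hysteresis_tracker
    # Clean step: refill the hysteresis budget (scale growth via growth_interval
    # is omitted in this sketch).
    return scale, hysteresis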